CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_simt.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"

#include "cutlass/gemm/thread/mma.h"

#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
45 namespace cutlass {
46 namespace gemm {
47 namespace warp {
48 
50 
52 template <
54  typename Shape_,
56  typename ElementA_,
58  typename LayoutA_,
60  typename ElementB_,
62  typename LayoutB_,
64  typename ElementC_,
66  typename LayoutC_,
68  typename Policy_,
70  int PartitionsK = 1,
72  typename Enable = bool
73 >
74 class MmaSimt {
75 public:
77  using Shape = Shape_;
78 
80  using ElementA = ElementA_;
81 
83  using LayoutA = LayoutA_;
84 
86  using ElementB = ElementB_;
87 
89  using LayoutB = LayoutB_;
90 
92  using ElementC = ElementC_;
93 
95  using LayoutC = LayoutC_;
96 
98  using Policy = Policy_;
99 
101  using OperatorClass = arch::OpClassSimt;
102 
107  LayoutA>::type
108  >::type;
109 
111  layout::ColumnMajor,
112  typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutB >::value,
113  layout::RowMajor,
114  LayoutB>::type
115  >::type;
116 
121 
123 
125  using ThreadMma = thread::Mma<
126  GemmShape<
127  Shape::kM / Policy::WarpShape::kRow,
128  Shape::kN / Policy::WarpShape::kColumn,
129  Policy::LaneMmaShape::kK>,
130  ElementA,
132  ElementB,
134  ElementC,
135  LayoutC,
136  arch::OpMultiplyAdd,
137  dp4a_type
138  >;
139 
140 public:
141 
145  Operand::kA,
146  ElementA,
147  LayoutA,
148  Policy,
149  PartitionsK,
150  Shape::kK
151  >;
152 
154  using FragmentA = typename IteratorA::Fragment;
155 
159  Operand::kB,
160  ElementB,
161  LayoutB,
162  Policy,
163  PartitionsK,
164  Shape::kK
165  >;
166 
168  using FragmentB = typename IteratorB::Fragment;
169 
173  Operand::kC,
174  ElementC,
175  LayoutC,
176  Policy
177  >;
178 
180  using FragmentC = typename ThreadMma::FragmentC;
181 
182 public:
183 
184  //
185  // Methods
186  //
187 
189  CUTLASS_DEVICE
190  MmaSimt() {}
191 
193  CUTLASS_DEVICE
195  FragmentC &d,
196  FragmentA const &a,
197  FragmentB const &b,
198  FragmentC const &c, int group_idx = 0) const {
199 
200  ThreadMma mma;
201 
202  mma(d, a, b, c);
203  }
204 };
205 
207 
208 } // namespace warp
209 } // namespace gemm
210 } // namespace cutlass
Describes the lane policy used by warp-level matrix multiply operators targeting SIMT instructions.
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
ElementC_ ElementC
Data type of accumulator matrix C.
Definition: mma_simt.h:92
Definition: aligned_buffer.h:35
#define constexpr
Definition: platform.h:137
Describes the lane policy used by warp-level matrix multiply operators targeting SIMT instructions.
T type
Definition: platform.h:326
std::is_same (false specialization)
Definition: platform.h:394
typename ThreadMma::FragmentC FragmentC
Storage for C tile.
Definition: mma_simt.h:180
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_simt.h:77
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_simt.h:74
Defines common types used for all GEMM-like operators.
CUTLASS_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c, int group_idx=0) const
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_simt.h:194
static constexpr bool use_dp4a
Compile-time flag that selects dp4a_type: int8_t when true, bool otherwise.
Definition: mma_simt.h:117
LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_simt.h:95
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Templates exposing architecture support for warp-level multiply-add operations.
Definition: mma_simt_tile_iterator.h:69
Defines a Shape template for matrix tiles.
arch::OpClassSimt OperatorClass
Indicates class of matrix operator.
Definition: mma_simt.h:101
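typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved< 4 >, LayoutB >::value, layout::ColumnMajor, typename platform::conditional< platform::is_same< layout::RowMajorInterleaved< 4 >, LayoutB >::value, layout::RowMajor, LayoutB >::type >::type ThreadLayoutB
Layout of B as seen by the thread-level MMA: ColumnMajorInterleaved<4> maps to ColumnMajor, RowMajorInterleaved<4> maps to RowMajor, and any other layout is used unchanged.
Definition: mma_simt.h:115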
LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_simt.h:83
Templates exposing architecture support for warp-level multiply-add operations.
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
std::conditional (true specialization)
Definition: platform.h:325
Policy_ Policy
Lane policy: supplies the arrangement of threads in the warp (Policy::WarpShape) and the per-lane MMA shape (Policy::LaneMmaShape) (concept: MmaLanePolicySimt)
Definition: mma_simt.h:98
typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_simt.h:154
typename platform::conditional< use_dp4a, int8_t, bool >::type dp4a_type
Final template argument passed to thread::Mma: int8_t when use_dp4a is true, otherwise bool.
Definition: mma_simt.h:122
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Structure to compute the matrix product.
Definition: gemm/thread/mma.h:66
ElementA_ ElementA
Data type of multiplicand A.
Definition: mma_simt.h:80
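typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved< 4 >, LayoutA >::value, layout::ColumnMajor, typename platform::conditional< platform::is_same< layout::RowMajorInterleaved< 4 >, LayoutA >::value, layout::RowMajor, LayoutA >::type >::type ThreadLayoutA
Layout of A as seen by the thread-level MMA: ColumnMajorInterleaved<4> maps to ColumnMajor, RowMajorInterleaved<4> maps to RowMajor, and any other layout is used unchanged.
Definition: mma_simt.h:108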
ElementB_ ElementB
Data type of multiplicand B.
Definition: mma_simt.h:86
Basic include for CUTLASS.
LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_simt.h:89
CUTLASS_DEVICE MmaSimt()
Ctor.
Definition: mma_simt.h:190
typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_simt.h:168