CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_tensor_op_sm70.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
      Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/mma.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"

#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute a warp-level matrix product targeting Volta Tensor Cores.
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape_,
  /// Data type of A elements
  typename ElementA_,
  /// Layout of A matrix (concept: MatrixLayout)
  typename LayoutA_,
  /// Data type of B elements
  typename ElementB_,
  /// Layout of B matrix (concept: MatrixLayout)
  typename LayoutB_,
  /// Element type of C matrix
  typename ElementC_,
  /// Layout of C matrix (concept: MatrixLayout)
  typename LayoutC_,
  /// Policy describing implementation details of warp-level GEMM targeting Tensor Cores
  typename Policy_,
  /// Used for partial specialization
  typename Enable = bool
>
class MmaVoltaTensorOp {
public:
  /// Shape of warp-level matrix operation (concept: GemmShape)
  using Shape = Shape_;

  /// Data type of multiplicand A
  using ElementA = ElementA_;

  /// Layout of multiplicand A
  using LayoutA = LayoutA_;

  /// Data type of multiplicand B
  using ElementB = ElementB_;

  /// Layout of multiplicand B
  using LayoutB = LayoutB_;

  /// Data type of accumulator matrix C
  using ElementC = ElementC_;

  /// Layout of accumulator matrix C
  using LayoutC = LayoutC_;

  /// Policy describing implementation details of warp-level GEMM targeting Tensor Cores
  using Policy = Policy_;

  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassTensorOp;

  /// Number of threads participating in warp-level matrix product
  static int const kThreadCount = 32;

  /// Shape of the interleaved tile (concept: GemmShape)
  using InterleavedTileShape = GemmShape<32, 32, 4>;

  static_assert(!(Shape::kM % InterleavedTileShape::kM) &&
                !(Shape::kN % InterleavedTileShape::kN),
                "Shape must be a multiple of InterleavedTileShape.");

public:
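
  // Worked example (hypothetical instantiation, not from this file): with
  // Shape = GemmShape<64, 64, 4> and a 16x16x4 underlying arch::Mma, the warp covers
  // 2x2 interleaved tiles (TileIterations below), each requiring 2x2 MMA operations
  // (MmaIterations below).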

  /// Iterates over the A operand in memory
  using IteratorA = MmaVoltaTensorOpMultiplicandTileIterator<
    MatrixShape<Shape::kM, Shape::kK>,
    Operand::kA,
    ElementA,
    LayoutA,
    MatrixShape<
      Policy::Operator::Shape::kM,
      Policy::Operator::Shape::kK
    >,
    Policy::OpDelta::kRow,
    kThreadCount
  >;

  /// Storage for A tile
  using FragmentA = typename IteratorA::Fragment;

  /// Iterates over the B operand in memory
  using IteratorB = MmaVoltaTensorOpMultiplicandTileIterator<
    MatrixShape<Shape::kK, Shape::kN>,
    Operand::kB,
    ElementB,
    LayoutB,
    MatrixShape<
      Policy::Operator::Shape::kK,
      Policy::Operator::Shape::kN
    >,
    Policy::OpDelta::kRow,
    kThreadCount
  >;

  /// Storage for B tile
  using FragmentB = typename IteratorB::Fragment;

  /// Iterates over the C operand in memory
  using IteratorC = MmaVoltaTensorOpAccumulatorTileIterator<
    MatrixShape<Shape::kM, Shape::kN>,
    ElementC,
    LayoutC,
    typename Policy::Operator::Shape,
    typename Policy::OpDelta
  >;

  /// Storage for C tile
  using FragmentC = typename IteratorC::Fragment;

private:

  static_assert(
    !(Shape::kM % Policy::Operator::Shape::kM) &&
    !(Shape::kN % Policy::Operator::Shape::kN),
    "Shape of warp-level Mma must be divisible by operator shape.");

  /// Number of mma operations performed per interleaved tile
  using MmaIterations = MatrixShape<
    InterleavedTileShape::kM / Policy::Operator::Shape::kM,
    InterleavedTileShape::kN / Policy::Operator::Shape::kN
  >;

  /// Number of interleaved tiles covered by the warp-level shape
  using TileIterations = MatrixShape<
    Shape::kM / InterleavedTileShape::kM,
    Shape::kN / InterleavedTileShape::kN
  >;

  // Whether matrix B is reordered
  bool reorder_B_;

public:

  /// Underlying matrix multiply operator (concept: arch::Mma)
  typename Policy::Operator mma;

public:

  //
  // Methods
  //

  /// Ctor
  CUTLASS_DEVICE
  MmaVoltaTensorOp() {}

  /// Performs a warp-level matrix multiply-accumulate operation
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    FragmentA const &A,
    FragmentB const &B,
    FragmentC const &C,
    int const &partitionN_idx = 0) {

    using MmaOperandA = typename Policy::Operator::FragmentA;
    using MmaOperandB = typename Policy::Operator::FragmentB;
    using MmaOperandC = typename Policy::Operator::FragmentC;

    D = C;

    MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
    MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
    MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
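
    // The warp-level fragments are contiguous arrays of per-instruction operand
    // fragments, so the casts above expose one entry per underlying arch::Mma
    // operation for indexed access in the loops below.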

    CUTLASS_PRAGMA_UNROLL
    for (int outer_col = 0; outer_col < TileIterations::kColumn; ++outer_col) {
      CUTLASS_PRAGMA_UNROLL
      for (int inner_col = 0; inner_col < MmaIterations::kColumn; ++inner_col) {
        CUTLASS_PRAGMA_UNROLL
        for (int outer_row = 0; outer_row < TileIterations::kRow; ++outer_row) {
          CUTLASS_PRAGMA_UNROLL
          for (int inner_row = 0; inner_row < MmaIterations::kRow; ++inner_row) {

            int op_col = inner_col + MmaIterations::kColumn * outer_col;

            // Column-major serpentine sequence to maximize reuse of the A operand:
            // odd columns traverse the rows in reverse order.
            int inner_row_serp = inner_row;
            int outer_row_serp = outer_row;
            if (op_col & 1) {
              inner_row_serp = MmaIterations::kRow - inner_row - 1;
              outer_row_serp = TileIterations::kRow - outer_row - 1;
            }

            int op_row = inner_row_serp + MmaIterations::kRow * outer_row_serp;
            int op_idx = inner_row_serp + MmaIterations::kRow *
                         (inner_col + MmaIterations::kColumn *
                          (outer_row_serp + TileIterations::kRow * outer_col));
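
            // Worked example (hypothetical sizes, not from this file): with
            // MmaIterations = <2, 2> and TileIterations = <1, 1>, op_col takes values
            // 0 then 1; the serpentine visits op_row as 0,1 for op_col == 0 and 1,0 for
            // op_col == 1, so the MMAs on either side of the column switch reuse ptr_A[1].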
            mma(
              ptr_D[op_idx],
              ptr_A[op_row],
              ptr_B[op_col],
              ptr_D[op_idx]);
          }
        }
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass
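
For orientation, a minimal mainloop sketch follows. It is not part of this header: WarpMma stands in for a fully specialized MmaVoltaTensorOp, and the iterator construction, names, and gemm_k_iterations count are assumed for illustration only.

// Hypothetical warp-level mainloop driving MmaVoltaTensorOp (sketch, not CUTLASS source).
template <typename WarpMma>
__device__ void warp_mainloop_sketch(
    typename WarpMma::IteratorA iter_A,  // assumed constructed over the A tile in shared memory
    typename WarpMma::IteratorB iter_B,  // assumed constructed over the B tile in shared memory
    typename WarpMma::FragmentC &accum,  // accumulator fragment, initialized by the caller
    int gemm_k_iterations) {             // number of k-group steps, assumed known

  typename WarpMma::FragmentA frag_A;
  typename WarpMma::FragmentB frag_B;

  WarpMma warp_mma;

  for (int k = 0; k < gemm_k_iterations; ++k) {
    iter_A.load(frag_A);   // each thread loads its share of the A tile
    iter_B.load(frag_B);
    ++iter_A;              // advance both iterators to the next k-group
    ++iter_B;

    // accum = A * B + accum, swept over the serpentine schedule shown above
    warp_mma(accum, frag_A, frag_B, accum);
  }
}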