CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_tensor_op.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include "cutlass/cutlass.h"
33 #include "cutlass/array.h"
34 
35 #include "cutlass/numeric_types.h"
36 #include "cutlass/matrix_shape.h"
37 
39 #include "cutlass/arch/mma_sm75.h"
40 #include "cutlass/gemm/gemm.h"
41 #include "cutlass/gemm/warp/mma.h"
42 
44 
47 
48 namespace cutlass {
49 namespace gemm {
50 namespace warp {
51 
53 
55 template <
57  typename Shape_,
59  typename ElementA_,
61  typename LayoutA_,
63  typename ElementB_,
65  typename LayoutB_,
67  typename ElementC_,
69  typename LayoutC_,
71  typename Policy_,
73  int PartitionsK_ = 1,
76  bool AccumulatorsInRowMajor = false,
78  int PartitionsN_ = 1,
80  typename Enable = bool
81 >
82 class MmaTensorOp {
83 public:
85  using Shape = Shape_;
86 
88  using ElementA = ElementA_;
89 
91  using LayoutA = LayoutA_;
92 
94  using ElementB = ElementB_;
95 
97  using LayoutB = LayoutB_;
98 
100  using ElementC = ElementC_;
101 
103  using LayoutC = LayoutC_;
104 
106  using Policy = Policy_;
107 
109  using OperatorClass = arch::OpClassTensorOp;
110 
112  static int const kThreadCount = 32;
113 
115  static int const kPartitionsK = PartitionsK_;
116 
118  static int const kPartitionsN = PartitionsN_;
119 
120 public:
121 
126  Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
127 
129  using FragmentA = typename IteratorA::Fragment;
130 
135  Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
136 
138  using FragmentB = typename IteratorB::Fragment;
139 
143  typename Policy::Operator::Shape, typename Policy::OpDelta>;
144 
146  using FragmentC = typename IteratorC::Fragment;
147 
148 private:
149 
151  !(Shape::kM % Policy::Operator::Shape::kM) &&
152  !(Shape::kN % Policy::Operator::Shape::kN),
153  "Shape of warp-level Mma must be divisible by operator shape.");
154 
156  using MmaIterations = MatrixShape<
157  Shape::kM / Policy::Operator::Shape::kM,
158  (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?
159  Shape::kN / Policy::Operator::Shape::kN / kPartitionsN :
160  1
161  >;
162 
163 public:
164 
166  typename Policy::Operator mma;
167 
168 public:
169 
170  //
171  // Methods
172  //
173 
175  CUTLASS_DEVICE
177 
179  CUTLASS_DEVICE
181  FragmentC &D,
182  FragmentA const &A,
183  FragmentB const &B,
184  FragmentC const &C,
185  int const &partitionN_idx = 0) const {
186 
187  using MmaOperandA = typename Policy::Operator::FragmentA;
188  using MmaOperandB = typename Policy::Operator::FragmentB;
189  using MmaOperandC = typename Policy::Operator::FragmentC;
190 
191  D = C;
192 
193  MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
194  MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
195  MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
196 
197  // The offset of multiplicand B for the current partition
198  const int n_off = partitionN_idx * FragmentB::kElements / MmaOperandB::kElements / kPartitionsN;
199  // Serpentine visitation order maximizing reuse of Rb
201  for (int n = 0; n < MmaIterations::kColumn; ++n) {
202 
204  for (int m = 0; m < MmaIterations::kRow; ++m) {
205 
206  int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
207 
208  if (AccumulatorsInRowMajor) { // matrix B is reordered
209  mma(
210  ptr_D[n + m_serpentine * MmaIterations::kColumn],
211  ptr_A[m_serpentine],
212  ptr_B[n],
213  ptr_D[n + m_serpentine * MmaIterations::kColumn]);
214  } else {
215  mma(
216  ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow],
217  ptr_A[m_serpentine],
218  ptr_B[n + n_off],
219  ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow]);
220  }
221  }
222  }
223  }
224 };
225 
227 
228 } // namespace warp
229 } // namespace gemm
230 } // namespace cutlass
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_tensor_op.h:129
Definition: aligned_buffer.h:35
LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_tensor_op.h:97
CUTLASS_DEVICE MmaTensorOp()
Ctor.
Definition: mma_tensor_op.h:176
Architecture-specific operators on memory added for SM75.
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
Defines common types used for all GEMM-like operators.
static int const kThreadCount
Number of threads participating in warp-level matrix product.
Definition: mma_tensor_op.h:112
static int const kPartitionsN
Number of partitions along the N dimension of multiplicand B.
Definition: mma_tensor_op.h:118
Structure to compute the warp-level matrix product targeting Tensor Cores.
Definition: mma_tensor_op.h:82
LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_tensor_op.h:91
typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_tensor_op.h:138
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations.
Defines a Shape template for matrix tiles.
typename IteratorC::Fragment FragmentC
Storage for C tile.
Definition: mma_tensor_op.h:146
Definition: mma_tensor_op_tile_iterator.h:1794
CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C, int const &partitionN_idx=0) const
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_tensor_op.h:180
Policy::Operator mma
Underlying matrix multiply operator (concept: arch::Mma)
Definition: mma_tensor_op.h:166
ElementC_ ElementC
Data type of accumulator matrix C.
Definition: mma_tensor_op.h:100
Top-level include for all CUTLASS numeric types.
Definition: mma_tensor_op_tile_iterator.h:75
#define static_assert(__e, __m)
Definition: platform.h:153
LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_tensor_op.h:103
Policy_ Policy
Policy describing the warp-level Tensor Core implementation; supplies the Operator and OpDelta types (concept: MmaTensorOpPolicy)
Definition: mma_tensor_op.h:106
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_tensor_op.h:85
ElementB_ ElementB
Data type of multiplicand B.
Definition: mma_tensor_op.h:94
static int const kPartitionsK
Number of partitions along K dimension.
Definition: mma_tensor_op.h:115
arch::OpClassTensorOp OperatorClass
Indicates class of matrix operator.
Definition: mma_tensor_op.h:109
ElementA_ ElementA
Data type of multiplicand A.
Definition: mma_tensor_op.h:88
Matrix multiply for SM75.
Basic include for CUTLASS.
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.