CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_tensor_op_wmma.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations
           targeting Tensor Cores via the CUDA WMMA API.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/wmma.h"

#if defined(CUTLASS_ARCH_WMMA_ENABLED)

#include "cutlass/wmma_array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/mma_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"

#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Warp-level matrix multiply-accumulate operation computed with the CUDA WMMA API,
/// targeting Tensor Cores.
template <
  /// Size of the warp-level GEMM problem (concept: gemm::GemmShape)
  typename Shape_,
  /// Data type of A elements
  typename ElementA_,
  /// Layout of A matrix
  typename LayoutA_,
  /// Data type of B elements
  typename ElementB_,
  /// Layout of B matrix
  typename LayoutB_,
  /// Data type of C elements
  typename ElementC_,
  /// Layout of C matrix
  typename LayoutC_,
  /// Policy describing the underlying WMMA operation (concept: MmaTensorOpPolicy)
  typename Policy_,
  /// Number of partitions along the K dimension
  int PartitionsK_ = 1,
  /// Number of partitions along the N dimension
  int PartitionsN_ = 1,
  /// Used for partial specialization
  typename Enable = bool
>
class MmaTensorOpWmma {
public:
  /// Shape of the warp-level matrix operation (concept: GemmShape)
  using Shape = Shape_;

  /// Data type of multiplicand A
  using ElementA = ElementA_;

  /// Layout of multiplicand A
  using LayoutA = LayoutA_;

  /// Data type of multiplicand B
  using ElementB = ElementB_;

  /// Layout of multiplicand B
  using LayoutB = LayoutB_;

  /// Data type of accumulator matrix C
  using ElementC = ElementC_;

  /// Layout of accumulator matrix C
  using LayoutC = LayoutC_;

  /// Policy describing implementation details of the warp-level operation
  using Policy = Policy_;

  /// Indicates the class of matrix operator
  using OperatorClass = arch::OpClassTensorOp;

  /// Number of threads participating in the warp-level matrix multiply
  static int const kThreadCount = 32;

  /// Number of partitions along the K dimension
  static int const kPartitionsK = PartitionsK_;

  /// Number of partitions along the N dimension
  static int const kPartitionsN = PartitionsN_;

public:

  /// Iterator for loading tiles of operand A
  using IteratorA = MmaTensorOpWmmaMultiplicandTileIterator<
     MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
     Policy::OpDelta::kRow, kThreadCount, Policy>;

  /// Storage for the A tile
  using FragmentA = typename IteratorA::Fragment;

  /// Iterator for loading tiles of operand B
  using IteratorB = MmaTensorOpWmmaMultiplicandTileIterator<
     MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
     Policy::OpDelta::kRow, kThreadCount, Policy>;

  /// Storage for the B tile
  using FragmentB = typename IteratorB::Fragment;

  /// Iterator for loading and storing accumulator tiles
  using IteratorC = MmaTensorOpWmmaAccumulatorTileIterator<
     MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
     typename Policy::OpDelta, Policy>;

  /// Storage for the accumulator tile
  using FragmentC = typename IteratorC::Fragment;

private:

  static_assert(
    !(Shape::kM % Policy::Operator::Shape::kM) &&
    !(Shape::kN % Policy::Operator::Shape::kN),
    "Shape of warp-level Wmma must be divisible by operator shape (wmma native size)");

  /// Number of WMMA operations performed along each dimension of the warp tile
  using WmmaIterations = MatrixShape<
    Shape::kM / Policy::Operator::Shape::kM,
    (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?
      Shape::kN / Policy::Operator::Shape::kN / kPartitionsN :
      1
  >;

public:

  /// Underlying instruction-level matrix operation
  typename Policy::Operator wmma;

public:

  //
  // Methods
  //

  /// Default constructor
  CUTLASS_DEVICE
  MmaTensorOpWmma() {}

  /// Performs a warp-level matrix multiply-accumulate operation
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    FragmentA const &A,
    FragmentB const &B,
    FragmentC const &C,
    int const &partitionN_idx = 0) const {

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < WmmaIterations::kColumn; ++n) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < WmmaIterations::kRow; ++m) {

        // accumulate wmma mma
        wmma(D[m * WmmaIterations::kColumn + n], A[m], B[n], C[m * WmmaIterations::kColumn + n]);
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
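For orientation, the sketch below shows how MmaTensorOpWmma might be instantiated directly. The warp tile shape, element types, layouts, and policy are illustrative assumptions only; in practice these parameters are usually produced by CUTLASS's default warp- and threadblock-level configuration templates rather than written by hand.

// Minimal usage sketch (not part of this header). Requires compiling for an architecture
// with WMMA support (e.g. SM70 or newer) so that CUTLASS_ARCH_WMMA_ENABLED is defined.
#include "cutlass/arch/mma.h"
#include "cutlass/arch/wmma.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_wmma.h"

// Warp-level tile and native WMMA instruction shape (illustrative choices).
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 16>;
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>;

// The policy pairs the architecture-level WMMA operator with the per-instruction delta.
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
    cutlass::arch::Wmma<
        InstructionShape,
        cutlass::half_t, cutlass::layout::RowMajor,      // A
        cutlass::half_t, cutlass::layout::ColumnMajor,   // B
        float, cutlass::layout::RowMajor,                // C
        cutlass::arch::OpMultiplyAdd>,
    cutlass::MatrixShape<1, 1>>;

using WarpMma = cutlass::gemm::warp::MmaTensorOpWmma<
    WarpShape,
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    float, cutlass::layout::RowMajor,
    Policy>;

// Inside a kernel, each warp would load WarpMma::FragmentA / FragmentB tiles through
// WarpMma::IteratorA / IteratorB, obtain (or initialize) a WarpMma::FragmentC accumulator
// via WarpMma::IteratorC, and then invoke the functor:
//
//   WarpMma mma;
//   mma(accum, frag_A, frag_B, accum);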