CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_pipelined.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/array.h"
33 #include "cutlass/aligned_buffer.h"
35 
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/matrix_shape.h"
38 
39 #include "cutlass/gemm/gemm.h"
41 
43 
44 namespace cutlass {
45 namespace gemm {
46 namespace threadblock {
47 
49 
// Template for a double-buffered threadblock-scoped GEMM kernel. The main loop below keeps two
// stages in flight (Base::kStages == 2 is enforced by a static_assert) and toggles
// smem_write_stage_idx between them each time a new tile is written to shared memory.
//
// NOTE(review): this text is a rendered listing. Numbered lines 89-90, 131-135, 141, 170, 238 and
// 244 are not shown, so some declarations used below (for example the `Base` alias and the
// declarator lines of the constructor and of operator()) are not visible here.
51 template <
  // Size of the Gemm problem - concept: gemm::GemmShape<>
53  typename Shape_,
  // Iterates over tiles of A operand in global memory
55  // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
56  typename IteratorA_,
  // Type of the iterator used to write the threadblock-scoped tile of A operand to shared memory
59  typename SmemIteratorA_,
  // Iterates over tiles of B operand in global memory
61  // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
62  typename IteratorB_,
  // Type of the iterator used to write the threadblock-scoped tile of B operand to shared memory
65  typename SmemIteratorB_,
  // Data type of accumulator matrix
67  typename ElementC_,
  // Layout of accumulator matrix
69  typename LayoutC_,
  // Policy describing tuning details; must provide Operator and kPartitionsK (both used below)
71  typename Policy_,
  // Functor applied to each A fragment loaded through IteratorA_ before it is stored through the
  // shared-memory iterator. Defaults to a NumericArrayConverter sized to the whole fragment.
73  typename TransformA_ = NumericArrayConverter<
74  typename SmemIteratorA_::Element,
75  typename IteratorA_::Element,
76  IteratorA_::Fragment::kElements>,
  // Same as TransformA_, for the B operand
79  typename TransformB_ = NumericArrayConverter<
80  typename SmemIteratorB_::Element,
81  typename IteratorB_::Element,
82  IteratorB_::Fragment::kElements>,
  // Not referenced anywhere in the visible listing
84  typename Enable = bool
85 >
86 class MmaPipelined : public MmaBase<Shape_, Policy_, 2> {
87 public:
88 
  // NOTE(review): listing lines 89-90 are not shown; presumably the `Base` alias for the MmaBase
  // base class is declared there - confirm against the header.
91 
92  using Shape = Shape_;
93  using IteratorA = IteratorA_;
94  using IteratorB = IteratorB_;
95  using ElementC = ElementC_;
96  using LayoutC = LayoutC_;
97  using Policy = Policy_;
98 
99  using SmemIteratorA = SmemIteratorA_;
100  using SmemIteratorB = SmemIteratorB_;
101 
102  using TransformA = TransformA_;
103  using TransformB = TransformB_;
104 
105  //
106  // Dependent types
107  //
108 
  // Fragment of operand A loaded from global memory
110  using FragmentA = typename IteratorA::Fragment;
111 
  // Fragment of operand B loaded from global memory
113  using FragmentB = typename IteratorB::Fragment;
114 
  // Fragment of accumulator tile
116  using FragmentC = typename Policy::Operator::FragmentC;
117 
  // Warp-level Mma
119  using Operator = typename Policy::Operator;
120 
121  // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
122  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
123 
124 private:
125 
  // Operand fragment types of the warp-level operator; used for the fragment pairs in operator()
126  using WarpFragmentA = typename Operator::FragmentA;
127  using WarpFragmentB = typename Operator::FragmentB;
128 
129 protected:
130 
  // NOTE(review): listing lines 131-135 are not shown. The page index places the members
  // smem_iterator_A_ (line 132) and smem_iterator_B_ (line 135) here, described as the iterators
  // that write the threadblock-scoped tiles of the A and B operands to shared memory.
133 
136 
137 public:
138 
  // Construct from tensor references.
  // NOTE(review): the declarator (listing line 141) is not shown; the page index gives it as
  // MmaPipelined(shared_storage, thread_idx, warp_idx, lane_idx).
  //
  // shared_storage: forwarded to Base; also supplies operand_A_ref() / operand_B_ref() for the
  //                 two shared-memory store iterators
  // thread_idx:     forwarded to Base and to both shared-memory store iterators
  // warp_idx:       forwarded to Base; decomposed below into per-warp m / n / k coordinates
  // lane_idx:       forwarded to Base only
140  CUTLASS_DEVICE
142  typename Base::SharedStorage &shared_storage,
143  int thread_idx,
144  int warp_idx,
145  int lane_idx
146  ):
147  Base(shared_storage, thread_idx, warp_idx, lane_idx),
148  smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
149  smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
150 
151  // Compute warp location within threadblock tile by mapping the warp_id to
152  // three coordinates:
153  // _m: the warp's position within the threadblock along the M dimension
154  // _n: the warp's position within the threadblock along the N dimension
155  // _k: the warp's position within the threadblock along the K dimension
156 
  // warp_idx is treated as k-major over (kM * kN) warps, then n-major over kM warps
157  int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
158  int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
159 
160  int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
161  int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
162 
163  // Add per-warp offsets in units of warp-level tiles
164  this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
165  this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
166  }
167 
  // Perform a threadblock-scoped matrix multiply-accumulate.
  // NOTE(review): the declarator (listing line 170) is not shown; the page index gives it as
  // `void operator()(...)` with the parameters below.
  //
  // gemm_k_iterations: trip count of the outer main loop (counted down to zero)
  // accum:             destination accumulator fragment; overwritten with src_accum first
  // iterator_A/B:      iterators over the A / B operands, taken by value and advanced here
  // src_accum:         source accumulator fragment
  // transform_A/B:     applied to each loaded fragment before it is stored to shared memory
169  CUTLASS_DEVICE
171  int gemm_k_iterations,
172  FragmentC &accum,
173  IteratorA iterator_A,
174  IteratorB iterator_B,
175  FragmentC const &src_accum,
176  TransformA transform_A = TransformA(),
177  TransformB transform_B = TransformB()) {
178 
179  //
180  // Prologue
181  //
182 
183  // Perform accumulation in the 'd' output operand
184  accum = src_accum;
185 
186  FragmentA tb_frag_A;
187  FragmentB tb_frag_B;
188 
189  tb_frag_A.clear();
190  tb_frag_B.clear();
191 
192  // The last kblock is loaded in the prolog
193  iterator_A.load(tb_frag_A);
194  iterator_B.load(tb_frag_B);
195 
196  ++iterator_A;
197  ++iterator_B;
198 
199  this->smem_iterator_A_.store(transform_A(tb_frag_A));
200  this->smem_iterator_B_.store(transform_B(tb_frag_B));
201 
202  ++this->smem_iterator_A_;
203  ++this->smem_iterator_B_;
204 
  // Barrier placed between the shared-memory stores above and the warp-level loads below
205  __syncthreads();
206 
207  // Pair of fragments used to overlap shared memory loads and math instructions
208  WarpFragmentA warp_frag_A[2];
209  WarpFragmentB warp_frag_B[2];
210 
211  this->warp_tile_iterator_A_.set_kgroup_index(0);
212  this->warp_tile_iterator_B_.set_kgroup_index(0);
213 
214  this->warp_tile_iterator_A_.load(warp_frag_A[0]);
215  this->warp_tile_iterator_B_.load(warp_frag_B[0]);
216 
217  ++this->warp_tile_iterator_A_;
218  ++this->warp_tile_iterator_B_;
219 
220  Operator warp_mma;
221 
  // Flipped (^= 1) after every in-loop store to shared memory; its value selects which pair of
  // iterators is rewound at that point (the smem store iterators or the warp tile iterators)
222  int smem_write_stage_idx = 1;
223 
224  // Avoid reading out of bounds
225  if (gemm_k_iterations <= 1) {
226  iterator_A.clear_mask();
227  iterator_B.clear_mask();
228  }
229 
230  // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
231  // shared memory loads (which have the tightest latency requirement).
232 
233  //
234  // Mainloop
235  //
236 
237  // Note: The main loop does not support Base::kWarpGemmIterations == 2.
  // NOTE(review): listing line 238 is not shown; the page index references CUTLASS_GEMM_LOOP,
  // presumably applied to the loop below - confirm against the header.
239  for (; gemm_k_iterations > 0; --gemm_k_iterations) {
240  //
241  // Loop over GEMM K dimension
242  //
243 
  // NOTE(review): listing line 244 is not shown; the page index references
  // CUTLASS_PRAGMA_UNROLL, presumably applied to the loop below - confirm against the header.
245  for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
246 
247  // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
248  // as the case may be.
249 
250  if (warp_mma_k == Base::kWarpGemmIterations - 1) {
251 
  // tb_frag_A / tb_frag_B hold the tile fetched in the `warp_mma_k == 0` branch further down
252  // Write fragments to shared memory
253  this->smem_iterator_A_.store(transform_A(tb_frag_A));
254 
255  this->smem_iterator_B_.store(transform_B(tb_frag_B));
256 
257  __syncthreads();
258 
259  ++this->smem_iterator_B_;
260  ++this->smem_iterator_A_;
261 
262  // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
263  if (smem_write_stage_idx == 1) {
264  this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
265  this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
266  }
267  else {
268  this->warp_tile_iterator_A_.add_tile_offset(
269  {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
270  this->warp_tile_iterator_B_.add_tile_offset(
271  {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
272  0});
273  }
274 
275  smem_write_stage_idx ^= 1;
276  }
277 
  // Fetch the next k-group into the fragment slot that the math below is NOT using:
  // the load targets slot (warp_mma_k + 1) % 2 while warp_mma consumes slot warp_mma_k % 2
278  this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
279  this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
280 
281  this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
282  this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
283 
284  ++this->warp_tile_iterator_A_;
285  ++this->warp_tile_iterator_B_;
286 
  // On the first warp-level step, load the next threadblock tile through iterator_A / iterator_B;
  // it is written to shared memory once warp_mma_k reaches Base::kWarpGemmIterations - 1
287  if (warp_mma_k == 0) {
288 
289  iterator_A.load(tb_frag_A);
290  iterator_B.load(tb_frag_B);
291 
292  ++iterator_A;
293  ++iterator_B;
294 
295  // Avoid reading out of bounds if this was the last loop iteration
296  if (gemm_k_iterations <= 2) {
297  iterator_A.clear_mask();
298  iterator_B.clear_mask();
299  }
300  }
301 
  // Accumulate in place: accum is passed as both the destination and the source accumulator
302  warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
303  }
304  }
305 
306  }
307 };
308 
310 
311 } // namespace threadblock
312 } // namespace gemm
313 } // namespace cutlass
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
LayoutC_ LayoutC
Layout of accumulator matrix.
Definition: mma_pipelined.h:96
TransformB_ TransformB
Definition: mma_pipelined.h:103
Definition: aligned_buffer.h:35
Policy_ Policy
Policy describing tuning details.
Definition: mma_pipelined.h:97
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of B operand from shared memory.
Definition: mma_base.h:193
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_pipelined.h:86
IteratorB_ IteratorB
Iterates over tiles of B operand in global memory.
Definition: mma_pipelined.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_DEVICE void operator()(int gemm_k_iterations, FragmentC &accum, IteratorA iterator_A, IteratorB iterator_B, FragmentC const &src_accum, TransformA transform_A=TransformA(), TransformB transform_B=TransformB())
Perform a threadblock-scoped matrix multiply-accumulate.
Definition: mma_pipelined.h:170
IteratorA_ IteratorA
Iterates over tiles of A operand in global memory.
Definition: mma_pipelined.h:93
typename IteratorB::Fragment FragmentB
Fragment of operand B loaded from global memory.
Definition: mma_pipelined.h:113
SmemIteratorA_ SmemIteratorA
Definition: mma_pipelined.h:99
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Boost-like numeric conversion operator for CUTLASS numeric types.
Defines a Shape template for matrix tiles.
static int const kWarpGemmIterations
Number of warp-level GEMM operations.
Definition: mma_base.h:108
Template for a double-buffered threadblock-scoped GEMM kernel.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: mma_pipelined.h:92
static int const kStages
Number of stages.
Definition: mma_base.h:112
typename IteratorA::Fragment FragmentA
Fragment of operand A loaded from global memory.
Definition: mma_pipelined.h:110
Top-level include for all CUTLASS numeric types.
#define static_assert(__e, __m)
Definition: platform.h:153
Definition: mma_base.h:83
typename Policy::Operator::FragmentC FragmentC
Fragment of accumulator tile.
Definition: mma_pipelined.h:116
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of A operand from shared memory.
Definition: mma_base.h:190
#define CUTLASS_GEMM_LOOP
Definition: cutlass.h:112
ElementC_ ElementC
Data type of accumulator matrix.
Definition: mma_pipelined.h:95
SmemIteratorA smem_iterator_A_
Iterator to write threadblock-scoped tile of A operand to shared memory.
Definition: mma_pipelined.h:132
SmemIteratorB_ SmemIteratorB
Definition: mma_pipelined.h:100
SmemIteratorB smem_iterator_B_
Iterator to write threadblock-scoped tile of B operand to shared memory.
Definition: mma_pipelined.h:135
CUTLASS_DEVICE MmaPipelined(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Construct from tensor references.
Definition: mma_pipelined.h:141
Basic include for CUTLASS.
TransformA_ TransformA
Definition: mma_pipelined.h:102
typename Policy::Operator Operator
Warp-level Mma.
Definition: mma_pipelined.h:119
static int const kN
Definition: include/cutlass/gemm/gemm.h:59