mma_singlestage.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape_,
  /// Iterates over tiles of A operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorA_,
  /// Iterates over tiles of A operand in shared memory
  //  (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorA_,
  /// Iterates over tiles of B operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorB_,
  /// Iterates over tiles of B operand in shared memory
  //  (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorB_,
  /// Data type of accumulator matrix
  typename ElementC_,
  /// Layout of accumulator matrix
  typename LayoutC_,
  /// Policy describing tuning details
  typename Policy_,
  /// Used for partial specialization
  typename Enable = bool
>
class MmaSingleStage : public MmaBase<Shape_, Policy_, 1> {
public:

  ///< Base class
  using Base = MmaBase<Shape_, Policy_, 1>;

  using Shape = Shape_;          ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA = IteratorA_;  ///< Iterates over tiles of A operand in global memory
  using IteratorB = IteratorB_;  ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_;    ///< Data type of accumulator matrix
  using LayoutC = LayoutC_;      ///< Layout of accumulator matrix
  using Policy = Policy_;        ///< Policy describing tuning details

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  // statically assert kStages for MmaSingleStage is 1 (single stage mma pipeline)
  static_assert((Base::kStages == 1), "MmaSingleStage requires kStages set to value 1");

private:

  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

protected:

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaSingleStage(
    typename Base::SharedStorage &shared_storage,  ///< Shared storage needed for internal use by threadblock-scoped GEMM
    int thread_idx,                                ///< ID within the threadblock
    int warp_idx,                                  ///< ID of warp
    int lane_idx                                   ///< ID of each thread within a warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
    smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
    int gemm_k_iterations,         ///< number of iterations of the mainloop
    FragmentC &accum,              ///< destination accumulator tile
    IteratorA iterator_A,          ///< iterator over A operand in global memory
    IteratorB iterator_B,          ///< iterator over B operand in global memory
    FragmentC const &src_accum) {  ///< source accumulator tile

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

    tb_frag_A.clear();
    tb_frag_B.clear();

    // The last kblock is loaded in the prolog
    iterator_A.load(tb_frag_A);
    iterator_B.load(tb_frag_B);

    ++iterator_A;
    ++iterator_B;

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];

    Operator warp_mma;

    // Avoid reading out of bounds
    if (gemm_k_iterations <= 1) {
      iterator_A.clear_mask();
      iterator_B.clear_mask();
    }

    //
    // Mainloop
    //

    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {

      this->smem_iterator_A_.store(tb_frag_A);
      this->smem_iterator_B_.store(tb_frag_B);

      __syncthreads();

      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to the k offset if this is
        // the last group.

        this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations);

        this->warp_tile_iterator_A_.load(warp_frag_A[warp_mma_k % 2]);
        this->warp_tile_iterator_B_.load(warp_frag_B[warp_mma_k % 2]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
      }

      // Add negative offsets to return smem load iterators to the 'start' of the shared memory
      this->warp_tile_iterator_A_.add_tile_offset({0, -Policy::kPartitionsK * Base::kWarpGemmIterations});
      this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});

      __syncthreads();

      iterator_A.load(tb_frag_A);
      iterator_B.load(tb_frag_B);

      ++iterator_A;
      ++iterator_B;

      // Avoid reading out of bounds if this was the last loop iteration
      if (gemm_k_iterations <= 2) {
        iterator_A.clear_mask();
        iterator_B.clear_mask();
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
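
For orientation, below is a minimal usage sketch, not part of mma_singlestage.h, of how a threadblock-scoped kernel typically drives MmaSingleStage: construct it from shared storage and the thread/warp/lane indices, then invoke operator() with the global-memory tile iterators and the accumulator fragment. The free function name threadblock_gemm_sketch is hypothetical, and the iterators and shared storage are assumed to be set up by the caller.

// Minimal usage sketch (illustrative only, not part of mma_singlestage.h).
// Assumes a concrete MmaSingleStage instantiation `Mma`; all names below that are
// not defined in the header above are placeholders for this sketch.
template <typename Mma>
CUTLASS_DEVICE void threadblock_gemm_sketch(
    int gemm_k_iterations,                               // number of mainloop iterations over K
    typename Mma::IteratorA iterator_A,                  // threadblock tile iterator over A in global memory
    typename Mma::IteratorB iterator_B,                  // threadblock tile iterator over B in global memory
    typename Mma::Base::SharedStorage &shared_storage,   // shared-memory staging buffers for A and B
    typename Mma::FragmentC &accum) {                    // per-thread accumulator fragment

  // Map the flat thread index onto threadblock / warp / lane coordinates
  int thread_idx = threadIdx.x;
  int warp_idx = threadIdx.x / 32;
  int lane_idx = threadIdx.x % 32;

  // Construct the threadblock-scoped matrix multiply
  Mma mma(shared_storage, thread_idx, warp_idx, lane_idx);

  // Start from a cleared accumulator and run the single-stage mainloop
  accum.clear();
  mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
}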