CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
include/cutlass/gemm/kernel/gemm.h
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
30 #pragma once
31 
32 #include "cutlass/cutlass.h"
33 
34 #include "cutlass/gemm/gemm.h"
35 #include "cutlass/matrix_coord.h"
36 #include "cutlass/semaphore.h"
37 
39 
40 namespace cutlass {
41 namespace gemm {
42 namespace kernel {
43 
45 
46 template <
47  typename Mma_,
48  typename Epilogue_,
49  typename ThreadblockSwizzle_,
50  bool SplitKSerial
51 >
52 struct Gemm {
53 
54  using Mma = Mma_;
55  using Epilogue = Epilogue_;
56  using OutputOp = typename Epilogue::OutputOp;
57  using ThreadblockSwizzle = ThreadblockSwizzle_;
58  static bool const kSplitKSerial = SplitKSerial;
59 
60  /// Warp count (concept: GemmShape)
61  using WarpCount = typename Mma::WarpCount;
62  static int const kThreadCount = 32 * WarpCount::kCount;
63 
64  /// Parameters structure
65  struct Params {
66  cutlass::gemm::GemmCoord problem_size;
67  cutlass::gemm::GemmCoord grid_tiled_shape;
68  typename Mma::IteratorA::Params params_A;
69  typename Mma::IteratorA::TensorRef ref_A;
70  typename Mma::IteratorB::Params params_B;
71  typename Mma::IteratorB::TensorRef ref_B;
72  typename Epilogue::OutputTileIterator::Params params_C;
73  typename Epilogue::OutputTileIterator::TensorRef ref_C;
74  typename Epilogue::OutputTileIterator::Params params_D;
75  typename Epilogue::OutputTileIterator::TensorRef ref_D;
76  typename OutputOp::Params output_op;
77  int *semaphore;
78  int gemm_k_iterations;
79  int gemm_k_size;
80 
81  //
82  // Methods
83  //
84 
85  CUTLASS_HOST_DEVICE
86  Params() { }
87 
88  CUTLASS_HOST_DEVICE
89  Params(
90  cutlass::gemm::GemmCoord const & problem_size,
91  cutlass::gemm::GemmCoord const & grid_tiled_shape,
92  typename Mma::IteratorA::TensorRef ref_A,
93  typename Mma::IteratorB::TensorRef ref_B,
94  typename Epilogue::OutputTileIterator::TensorRef ref_C,
95  typename Epilogue::OutputTileIterator::TensorRef ref_D,
96  typename OutputOp::Params output_op = typename OutputOp::Params(),
97  int *semaphore = nullptr
98  ):
99  problem_size(problem_size),
100  grid_tiled_shape(grid_tiled_shape),
101  params_A(ref_A.layout()),
102  ref_A(ref_A),
103  params_B(ref_B.layout()),
104  ref_B(ref_B),
105  params_C(ref_C.layout()),
106  ref_C(ref_C),
107  params_D(ref_D.layout()),
108  ref_D(ref_D),
109  output_op(output_op),
110  semaphore(semaphore) {
111 
112  int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
113  int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
114 
115  gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
116  }
117  };
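// Worked example of the split-K partitioning above, using illustrative values
// (not taken from this file): suppose problem_size.k() == 4096, Mma::Shape::kK == 32,
// and grid_tiled_shape.k() == 3 split-K partitions. Then
//   total_gemm_k_iterations = (4096 + 32 - 1) / 32 = 128
//   gemm_k_iterations       = (128  +  3 - 1) /  3 =  43
//   gemm_k_size             = 43 * 32              = 1376
// so partitions 0 and 1 each cover 1376 elements of K, and partition 2 covers the
// remaining 4096 - 2 * 1376 = 1344, because operator() clamps its K extent with
// min(problem_size.k(), (k + 1) * gemm_k_size).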
118 
119  /// Shared memory storage structure
120  union SharedStorage {
121  typename Mma::SharedStorage main_loop;
122  typename Epilogue::SharedStorage epilogue;
123  };
124 
125  //
126  // Methods
127  //
128 
129  CUTLASS_HOST_DEVICE
130  Gemm() { }
131 
132  /// Determines whether kernel satisfies alignment
133  static Status can_implement(
134  cutlass::gemm::GemmCoord const & problem_size,
135  typename Mma::IteratorA::TensorRef ref_A,
136  typename Mma::IteratorB::TensorRef ref_B,
137  typename Epilogue::OutputTileIterator::TensorRef ref_C,
138  typename Epilogue::OutputTileIterator::TensorRef ref_D) {
139 
140  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
141  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
142  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
143 
144  if (!TensorRef_aligned(ref_A, kAlignmentA)) {
145  return Status::kErrorMisalignedOperand;
146  }
147 
148  if (!TensorRef_aligned(ref_B, kAlignmentB)) {
149  return Status::kErrorMisalignedOperand;
150  }
151 
152  if (!TensorRef_aligned(ref_C, kAlignmentC)) {
153  return Status::kErrorMisalignedOperand;
154  }
155 
156  if (!TensorRef_aligned(ref_D, kAlignmentC)) {
157  return Status::kErrorMisalignedOperand;
158  }
159 
160  if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
161  (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
162  (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
163 
164  return Status::kErrorMisalignedOperand;
165  }
166 
167  return Status::kSuccess;
168  }
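// Illustration of the checks above, assuming a kernel configured for 128-bit accesses on
// half-precision operands, i.e. kAlignmentA == kAlignmentB == kAlignmentC == 8 elements:
// a 4096x4096x4096 problem satisfies every modulus test, while m == 4095 fails
// (problem_size.m() % kAlignmentA) and (problem_size.m() % kAlignmentC), so
// can_implement returns Status::kErrorMisalignedOperand.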
169 
170  /// Executes one GEMM
171  CUTLASS_DEVICE
172  void operator()(Params const &params, SharedStorage &shared_storage) {
173 
174  // Compute threadblock location
175  ThreadblockSwizzle threadblock_swizzle;
176 
177  cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
178 
179  // Early exit if CTA is out of range
180  if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
181  params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
182 
183  return;
184  }
185 
186  // Compute initial location in logical coordinates
187  cutlass::MatrixCoord tb_offset_A{
188  threadblock_tile_offset.m() * Mma::Shape::kM,
189  threadblock_tile_offset.k() * params.gemm_k_size,
190  };
191 
192  cutlass::MatrixCoord tb_offset_B{
193  threadblock_tile_offset.k() * params.gemm_k_size,
194  threadblock_tile_offset.n() * Mma::Shape::kN
195  };
196 
197  // Problem size is a function of threadblock index in the K dimension
198  int problem_size_k = min(
199  params.problem_size.k(),
200  (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
201 
202  // Compute threadblock-scoped matrix multiply-add
203  int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
204 
205  // Compute position within threadblock
206  int thread_idx = threadIdx.x;
207 
208  // Construct iterators to A and B operands
209  typename Mma::IteratorA iterator_A(
210  params.params_A,
211  params.ref_A.data(),
212  {params.problem_size.m(), problem_size_k},
213  thread_idx,
214  tb_offset_A);
215 
216  typename Mma::IteratorB iterator_B(
217  params.params_B,
218  params.ref_B.data(),
219  {problem_size_k, params.problem_size.n()},
220  thread_idx,
221  tb_offset_B);
222 
223  int warp_idx = threadIdx.x / 32;
224  int lane_idx = threadIdx.x % 32;
225 
226  //
227  // Main loop
228  //
229 
230  // Construct thread-scoped matrix multiply
231  Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
232 
233  typename Mma::FragmentC accumulators;
234 
235  accumulators.clear();
236 
237  if (!kSplitKSerial || gemm_k_iterations > 0) {
238  // Compute threadblock-scoped matrix multiply-add
239  mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
240  }
241 
242  //
243  // Epilogue
244  //
245 
246  OutputOp output_op(params.output_op);
247 
248  //
249  // Masked tile iterators constructed from members
250  //
251 
252  threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
253 
254  //assume identity swizzle
255  MatrixCoord threadblock_offset(
256  threadblock_tile_offset.m() * Mma::Shape::kM,
257  threadblock_tile_offset.n() * Mma::Shape::kN
258  );
259 
260  int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
261 
262  // Construct the semaphore.
263  Semaphore semaphore(params.semaphore + block_idx, thread_idx);
264 
265  // If performing a reduction via split-K, fetch the initial synchronization
266  if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
267 
268  // Fetch the synchronization lock initially but do not block.
269  semaphore.fetch();
270 
271  // Indicate which position in a serial reduction the output operator is currently updating
272  output_op.set_k_partition(threadblock_tile_offset.k());
273  }
274 
275  // Tile iterator loading from source tensor.
276  typename Epilogue::OutputTileIterator iterator_C(
277  params.params_C,
278  params.ref_C.data(),
279  params.problem_size.mn(),
280  thread_idx,
281  threadblock_offset
282  );
283 
284  // Tile iterator writing to destination tensor.
285  typename Epilogue::OutputTileIterator iterator_D(
286  params.params_D,
287  params.ref_D.data(),
288  params.problem_size.mn(),
289  thread_idx,
290  threadblock_offset
291  );
292 
293  Epilogue epilogue(
294  shared_storage.epilogue,
295  thread_idx,
296  warp_idx,
297  lane_idx);
298 
299  // Wait on the semaphore - this latency may have been covered by iterator construction
300  if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
301 
302  // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
303  if (threadblock_tile_offset.k()) {
304  iterator_C = iterator_D;
305  }
306 
307  semaphore.wait(threadblock_tile_offset.k());
308 
309  __threadfence();
310  }
311 
312  // Execute the epilogue operator to update the destination tensor.
313  epilogue(output_op, iterator_D, accumulators, iterator_C);
314 
315  //
316  // Release the semaphore
317  //
318 
319  if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
320 
321  int lock = 0;
322  if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
323 
324  // The final threadblock resets the semaphore for subsequent grids.
325  lock = 0;
326  }
327  else {
328  // Otherwise, the semaphore is incremented
329  lock = threadblock_tile_offset.k() + 1;
330  }
331 
332  __threadfence();
333  semaphore.release(lock);
334  }
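// Recap of the serial split-K protocol implemented above: partition k == 0 proceeds as soon
// as it observes the zero-initialized semaphore; each partition k > 0 waits for the value k
// and reads the partial sums back through iterator_C, which was redirected to the 'D' tensor;
// after the epilogue, every partition except the last releases k + 1, and the final partition
// resets the semaphore to 0 so the workspace can be reused.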
335  }
336 };
337 
339 
340 } // namespace kernel
341 } // namespace gemm
342 } // namespace cutlass
343 
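For context, the sketch below shows one way a host program might drive this kernel-level Gemm: computing the tiled shape, sizing and zeroing the per-tile semaphore workspace used by serial split-K, filling Params, and launching through a small __global__ wrapper. The wrapper gemm_kernel_entry, the launch_gemm helper, and the GemmKernel template parameter are illustrative assumptions rather than APIs defined in this header, and the grid mapping assumes an identity threadblock swizzle, matching the "assume identity swizzle" comment in operator().

#include <cstdint>
#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm.h"

// Hypothetical wrapper: dynamic shared memory backs GemmKernel::SharedStorage.
template <typename GemmKernel>
__global__ void gemm_kernel_entry(typename GemmKernel::Params params) {
  extern __shared__ uint8_t smem[];
  typename GemmKernel::SharedStorage &shared_storage =
      *reinterpret_cast<typename GemmKernel::SharedStorage *>(smem);
  GemmKernel op;
  op(params, shared_storage);
}

// Hypothetical host-side launcher for some instantiation such as
//   using GemmKernel = cutlass::gemm::kernel::Gemm<Mma, Epilogue, Swizzle, true>;
template <typename GemmKernel>
cutlass::Status launch_gemm(
    cutlass::gemm::GemmCoord problem_size,
    int split_k_slices,
    typename GemmKernel::Mma::IteratorA::TensorRef ref_A,
    typename GemmKernel::Mma::IteratorB::TensorRef ref_B,
    typename GemmKernel::Epilogue::OutputTileIterator::TensorRef ref_C,
    typename GemmKernel::Epilogue::OutputTileIterator::TensorRef ref_D) {

  // Reject misaligned operands up front.
  cutlass::Status status =
      GemmKernel::can_implement(problem_size, ref_A, ref_B, ref_C, ref_D);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // One threadblock tile per (M, N) block, split_k_slices partitions along K.
  cutlass::gemm::GemmCoord grid_tiled_shape(
      (problem_size.m() + GemmKernel::Mma::Shape::kM - 1) / GemmKernel::Mma::Shape::kM,
      (problem_size.n() + GemmKernel::Mma::Shape::kN - 1) / GemmKernel::Mma::Shape::kN,
      split_k_slices);

  // Serial split-K uses one int per output tile, indexed as m + n * grid_m.
  // It must start zeroed so that partition 0 passes its initial semaphore.wait(0).
  int *workspace = nullptr;
  size_t workspace_bytes = sizeof(int) * grid_tiled_shape.m() * grid_tiled_shape.n();
  cudaMalloc(&workspace, workspace_bytes);
  cudaMemset(workspace, 0, workspace_bytes);

  typename GemmKernel::Params params(
      problem_size, grid_tiled_shape, ref_A, ref_B, ref_C, ref_D,
      typename GemmKernel::OutputOp::Params(), workspace);

  // Identity swizzle: blockIdx.{x, y, z} maps directly to tile {m, n, k}.
  dim3 grid(grid_tiled_shape.m(), grid_tiled_shape.n(), grid_tiled_shape.k());
  dim3 block(GemmKernel::kThreadCount, 1, 1);
  int smem_bytes = int(sizeof(typename GemmKernel::SharedStorage));

  // Large tile shapes may need cudaFuncSetAttribute(...,
  // cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes) before this launch.
  gemm_kernel_entry<GemmKernel><<<grid, block, smem_bytes>>>(params);

  cudaError_t err = cudaDeviceSynchronize();
  cudaFree(workspace);
  return (err == cudaSuccess) ? cutlass::Status::kSuccess : cutlass::Status::kErrorInternal;
}

In practice an application would normally go through the device-level wrapper (cutlass::gemm::device::Gemm), which performs this plumbing - workspace sizing, threadblock swizzling, and launch configuration - on the caller's behalf.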