CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
kernel/gemm_splitk_parallel.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 
33 #include "cutlass/gemm/gemm.h"
34 #include "cutlass/matrix_coord.h"
35 
37 
namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Kernel computing one K-slice of a split-K parallel GEMM.
///
/// The K dimension is divided into grid_tiled_shape.k() slices. Each threadblock
/// computes the partial product for its slice and writes it into a distinct
/// partition of the D workspace (partitions are splitk_slice_stride elements
/// apart); a separate reduction kernel is expected to combine the partitions.
template <
  typename Mma_,                  ///< Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///< Epilogue writing the partial product
  typename ThreadblockSwizzle_    ///< Threadblock swizzling function
>
struct GemmSplitKParallel {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using OutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;

  /// Threads per threadblock: 32 threads per warp times warps per block
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// K-dimension alignment granularity (one warp-level MMA shape in K)
  static int const kAlignmentK = Mma::Operator::Shape::kK;

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorA::TensorRef ref_A;
    typename Mma::IteratorB::Params params_B;
    typename Mma::IteratorB::TensorRef ref_B;
    typename Epilogue::OutputTileIterator::Params params_D;
    typename Epilogue::OutputTileIterator::TensorRef ref_D;
    typename OutputOp::Params output_op;
    int64_t splitk_slice_stride;   ///< pointer offset between consecutive K-slice
                                   ///  partitions of the D workspace (applied via
                                   ///  iterator_D.add_pointer_offset below)
    int gemm_k_size;               ///< number of K elements assigned to each slice
                                   ///  (multiple of Mma::Shape::kK; computed in ctor)

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const & problem_size,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      typename Mma::IteratorA::TensorRef ref_A,
      typename Mma::IteratorB::TensorRef ref_B,
      typename Epilogue::OutputTileIterator::TensorRef ref_D,
      typename OutputOp::Params output_op,
      int64_t splitk_slice_stride
    ):
      problem_size(problem_size),
      grid_tiled_shape(grid_tiled_shape),
      params_A(ref_A.layout()),
      ref_A(ref_A),
      params_B(ref_B.layout()),
      ref_B(ref_B),
      params_D(ref_D.layout()),
      ref_D(ref_D),
      output_op(output_op),
      splitk_slice_stride(splitk_slice_stride) {

      // Divide the whole Mma::Shape::kK iterations evenly among the K-slices;
      // the final slice absorbs any remainder (see operator()).
      int full_gemm_k_iterations = problem_size.k() / Mma::Shape::kK;
      int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();

      gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
    }
  };

  /// Shared memory storage structure
  /// NOTE(review): reconstructed as a struct; upstream may declare this as a
  /// union to overlap mainloop and epilogue storage -- confirm against source.
  struct SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  GemmSplitKParallel() { }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.k() * params.gemm_k_size,
    };

    cutlass::MatrixCoord tb_offset_B{
      threadblock_tile_offset.k() * params.gemm_k_size,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Problem size is a function of threadblock index in the K dimension
    int problem_size_k;
    if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) {
      // Last K-slice extends to the true end of K, absorbing the remainder
      problem_size_k = params.problem_size.k();
    }
    else {
      problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
    }

    // Compute threadblock-scoped matrix multiply-add (ceil-div over the
    // K extent belonging to this slice)
    int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      params.ref_A.data(),
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      params.ref_B.data(),
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B);

    int warp_idx = threadIdx.x / 32;
    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

    //
    // Epilogue
    //

    OutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    // Tile iterator writing to output tile
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      params.ref_D.data(),
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    // Advance D to this K-slice's partition of the workspace
    iterator_D.add_pointer_offset(params.splitk_slice_stride * threadblock_tile_offset.k());

    // Execute the epilogue
    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Run efficient epilogue
    epilogue(output_op, iterator_D, accumulators, iterator_D);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
244 
CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)
Executes one GEMM.
Definition: kernel/gemm_splitk_parallel.h:126
CUTLASS_HOST_DEVICE GemmSplitKParallel()
Definition: kernel/gemm_splitk_parallel.h:122
Definition: aligned_buffer.h:35
Epilogue_ Epilogue
Definition: kernel/gemm_splitk_parallel.h:52
cutlass::gemm::GemmCoord problem_size
Definition: kernel/gemm_splitk_parallel.h:64
Shared memory storage structure.
Definition: kernel/gemm_splitk_parallel.h:112
Epilogue::SharedStorage epilogue
Definition: kernel/gemm_splitk_parallel.h:114
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:171
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: kernel/gemm_splitk_parallel.h:65
static int const kThreadCount
Definition: kernel/gemm_splitk_parallel.h:58
Mma::SharedStorage main_loop
Definition: kernel/gemm_splitk_parallel.h:113
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Parameters structure.
Definition: kernel/gemm_splitk_parallel.h:63
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape)
Definition: kernel/gemm_splitk_parallel.h:57
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: kernel/gemm_splitk_parallel.h:54
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size, cutlass::gemm::GemmCoord const &grid_tiled_shape, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op, int64_t splitk_slice_stride)
Definition: kernel/gemm_splitk_parallel.h:84
OutputOp::Params output_op
Definition: kernel/gemm_splitk_parallel.h:72
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
Mma::IteratorA::TensorRef ref_A
Definition: kernel/gemm_splitk_parallel.h:67
Mma::IteratorB::TensorRef ref_B
Definition: kernel/gemm_splitk_parallel.h:69
int gemm_k_size
Definition: kernel/gemm_splitk_parallel.h:74
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: kernel/gemm_splitk_parallel.h:71
CUTLASS_HOST_DEVICE Params()
Definition: kernel/gemm_splitk_parallel.h:81
Epilogue::OutputTileIterator::Params params_D
Definition: kernel/gemm_splitk_parallel.h:70
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Mma::IteratorA::Params params_A
Definition: kernel/gemm_splitk_parallel.h:66
static int const kAlignmentK
Definition: kernel/gemm_splitk_parallel.h:60
Defines a canonical coordinate for rank=2 matrices offering named indices.
Definition: kernel/gemm_splitk_parallel.h:49
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Mma_ Mma
Definition: kernel/gemm_splitk_parallel.h:51
Mma::IteratorB::Params params_B
Definition: kernel/gemm_splitk_parallel.h:68
int64_t splitk_slice_stride
Definition: kernel/gemm_splitk_parallel.h:73
Basic include for CUTLASS.
Definition: matrix_coord.h:39
typename Epilogue::OutputOp OutputOp
Definition: kernel/gemm_splitk_parallel.h:53