CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
kernel/gemm_batched.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                 ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,            ///! Epilogue
  typename ThreadblockSwizzle_   ///! Threadblock swizzling function
>
struct GemmBatched {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using OutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorA::TensorRef ref_A;
    int64_t stride_A;
    typename Mma::IteratorB::Params params_B;
    typename Mma::IteratorB::TensorRef ref_B;
    int64_t stride_B;
    typename Epilogue::OutputTileIterator::Params params_C;
    typename Epilogue::OutputTileIterator::TensorRef ref_C;
    int64_t stride_C;
    typename Epilogue::OutputTileIterator::Params params_D;
    typename Epilogue::OutputTileIterator::TensorRef ref_D;
    int64_t stride_D;
    typename OutputOp::Params epilogue;
    int batch_count;
    int gemm_k_iterations;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const & problem_size_,
      cutlass::gemm::GemmCoord const & grid_tiled_shape_,
      typename Mma::IteratorA::TensorRef ref_A_,
      int64_t stride_A_,
      typename Mma::IteratorB::TensorRef ref_B_,
      int64_t stride_B_,
      typename Epilogue::OutputTileIterator::TensorRef ref_C_,
      int64_t stride_C_,
      typename Epilogue::OutputTileIterator::TensorRef ref_D_,
      int64_t stride_D_,
      typename OutputOp::Params epilogue_,
      int batch_count_
    ):
      problem_size(problem_size_),
      grid_tiled_shape(grid_tiled_shape_),
      params_A(ref_A_.layout()),
      ref_A(ref_A_),
      stride_A(stride_A_),
      params_B(ref_B_.layout()),
      ref_B(ref_B_),
      stride_B(stride_B_),
      params_C(ref_C_.layout()),
      ref_C(ref_C_),
      stride_C(stride_C_),
      params_D(ref_D_.layout()),
      ref_D(ref_D_),
      stride_D(stride_D_),
      epilogue(epilogue_),
      batch_count(batch_count_),
      // Mainloop iteration count: ceiling division of K by the threadblock tile's K extent
      gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) {

    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  GemmBatched() { }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    // Each CTA handles multiple batch indices to accommodate the limited range of the CUDA grid's Z dimension
    for (int batch_idx = threadblock_swizzle.get_batch_idx();
      batch_idx < params.batch_count;
      batch_idx += gridDim.z) {

      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        0
      };

      cutlass::MatrixCoord tb_offset_B{
        0,
        threadblock_tile_offset.n() * Mma::Shape::kN
      };

      // Compute position within threadblock
      int thread_idx = threadIdx.x;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        params.params_A,
        params.ref_A.data(),
        params.problem_size.mk(),
        thread_idx,
        tb_offset_A);

      iterator_A.add_pointer_offset(params.stride_A * batch_idx);

      typename Mma::IteratorB iterator_B(
        params.params_B,
        params.ref_B.data(),
        params.problem_size.kn(),
        thread_idx,
        tb_offset_B);

      iterator_B.add_pointer_offset(params.stride_B * batch_idx);

      //
      // Main loop
      //

      // Construct thread-scoped matrix multiply
      int warp_idx = threadIdx.x / 32;
      int lane_idx = threadIdx.x % 32;

      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

      typename Mma::FragmentC accumulators;

      accumulators.clear();

      // Compute threadblock-scoped matrix multiply-add
      mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

      //
      // Epilogue
      //

      OutputOp output_op(params.epilogue);

      //
      // Masked tile iterators constructed from members
      //

      threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

      // Assume identity swizzle
      MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM,
        threadblock_tile_offset.n() * Mma::Shape::kN
      );

      // Tile iterator reading the source tile (operand C)
      typename Epilogue::OutputTileIterator iterator_C(
        params.params_C,
        params.ref_C.data(),
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );

      iterator_C.add_pointer_offset(params.stride_C * batch_idx);

      // Tile iterator writing to the output tile (operand D)
      typename Epilogue::OutputTileIterator iterator_D(
        params.params_D,
        params.ref_D.data(),
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );

      iterator_D.add_pointer_offset(params.stride_D * batch_idx);

      Epilogue epilogue(
        shared_storage.epilogue,
        thread_idx,
        warp_idx,
        lane_idx);

      // Run efficient epilogue
      epilogue(output_op, iterator_D, accumulators, iterator_C);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
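For context, a minimal host-side sketch of how this kernel is usually reached in practice: the device-level wrapper cutlass::gemm::device::GemmBatched (declared in cutlass/gemm/device/gemm_batched.h) instantiates kernel::GemmBatched internally and populates the Params structure shown above. The sketch below follows the pattern of the CUTLASS batched-GEMM example; the helper name run_batched_sgemm and its argument list are illustrative, and all pointers are assumed to reference valid device memory.

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_batched.h"

// Illustrative helper (not part of CUTLASS): computes
//   C = alpha * A * B + beta * C
// for batch_count independent GEMMs laid out with fixed batch strides.
cutlass::Status run_batched_sgemm(
    int m, int n, int k, int batch_count,
    float alpha, float const *A, int lda, int64_t stride_A,
    float const *B, int ldb, int64_t stride_B,
    float beta, float *C, int ldc, int64_t stride_C) {

  // Column-major SGEMM; the wrapper supplies default threadblock/warp shapes.
  using GemmBatched = cutlass::gemm::device::GemmBatched<
      float, cutlass::layout::ColumnMajor,   // ElementA, LayoutA
      float, cutlass::layout::ColumnMajor,   // ElementB, LayoutB
      float, cutlass::layout::ColumnMajor>;  // ElementC, LayoutC

  GemmBatched gemm_op;

  // D aliases C here, so the epilogue updates C in place.
  return gemm_op({
      {m, n, k},
      {A, lda}, stride_A,
      {B, ldb}, stride_B,
      {C, ldc}, stride_C,
      {C, ldc}, stride_C,
      {alpha, beta},
      batch_count});
}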
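At the kernel level, launching GemmBatched directly mirrors what the device-level wrapper does internally: size the grid from the tiled problem shape so that blockIdx.z covers the batch (the for loop in operator() picks up any batches beyond the grid's Z-dimension limit), launch kThreadCount threads per CTA, and reserve sizeof(SharedStorage) bytes of dynamic shared memory. A hedged sketch, assuming GemmKernel is a concrete instantiation of kernel::GemmBatched (e.g. with cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle) and that params was built with the constructor shown above:

#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"   // cutlass::Kernel<Operator>

// Sketch only: GemmKernel stands for a concrete kernel::GemmBatched instantiation.
template <typename GemmKernel>
cudaError_t launch_batched_gemm(typename GemmKernel::Params const &params,
                                cutlass::gemm::GemmCoord grid_tiled_shape,
                                cudaStream_t stream = nullptr) {

  typename GemmKernel::ThreadblockSwizzle swizzle;

  // One CTA per output tile in X/Y; Z strides over batch indices.
  dim3 grid = swizzle.get_grid_shape(grid_tiled_shape);
  dim3 block(GemmKernel::kThreadCount, 1, 1);

  // Dynamic shared memory for the kernel's SharedStorage union.
  // (Sizes beyond 48 KB would additionally require cudaFuncSetAttribute
  // with cudaFuncAttributeMaxDynamicSharedMemorySize.)
  int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

  // cutlass::Kernel<T> is the __global__ shim that places SharedStorage in
  // dynamic shared memory and invokes GemmKernel::operator().
  cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);

  return cudaGetLastError();
}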