CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm/threadblock/threadblock_swizzle.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include "cutlass/cutlass.h"
33 
34 #include "cutlass/gemm/gemm.h"
35 
37 
38 namespace cutlass {
39 namespace gemm {
40 namespace threadblock {
41 
43 
45 CUTLASS_DEVICE
47  return threadIdx.x;
48 }
49 
51 CUTLASS_DEVICE
53  return threadIdx.y;
54 }
55 
57 CUTLASS_DEVICE
59  return threadIdx.z;
60 }
61 
63 CUTLASS_DEVICE
65  return blockIdx.x;
66 }
67 
69 CUTLASS_DEVICE
71  return blockIdx.y;
72 }
73 
75 CUTLASS_DEVICE
77  return blockIdx.z;
78 }
79 
81 CUTLASS_DEVICE
83  return blockDim.x;
84 }
85 
87 CUTLASS_DEVICE
89  return blockDim.y;
90 }
91 
93 CUTLASS_DEVICE
95  return blockDim.z;
96 }
97 
99 
102 
105 
106  int const kTile = 1;
107 
111  GemmCoord problem_size,
112  GemmCoord tile_size,
113  int split_k_slices) const {
114 
115  return GemmCoord(
116  (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
117  (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
118  split_k_slices);
119  }
120 
123  dim3 get_grid_shape(GemmCoord tiled_shape) const {
124  return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());
125  }
126 
128  CUTLASS_DEVICE
130 
131  int block_idx_x = RematerializeBlockIdxX();
132  int block_idx_y = RematerializeBlockIdxY();
133 
134  return GemmCoord{
135  (block_idx_x / kTile),
136  (block_idx_y * kTile) + (block_idx_x % kTile),
138  };
139  }
140 };
141 
143 
146 
149 
153  GemmCoord problem_size,
154  GemmCoord tile_size,
155  int split_k_slices) const {
156 
157  return GemmCoord(
158  (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
159  (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
160  split_k_slices);
161  }
162 
165  dim3 get_grid_shape(GemmCoord tiled_shape) const {
166  return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k());
167  }
168 
170  CUTLASS_DEVICE
172  return GemmCoord{
176  };
177  }
178 };
179 
181 
184 
188  GemmCoord problem_size,
189  int batch_count,
190  GemmCoord tile_size) const {
191 
192  return GemmCoord(
193  (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
194  (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
195  batch_count % (1 << 16));
196  }
197 
200  dim3 get_grid_shape(GemmCoord tiled_shape) const {
201  return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
202  }
203 
205  CUTLASS_DEVICE
207  return GemmCoord{
210  0
211  };
212  }
213 
215  CUTLASS_DEVICE
216  int get_batch_idx() const {
217  return RematerializeBlockIdxZ();
218  }
219 };
220 
222 
225 
229  GemmCoord problem_size,
230  GemmCoord tile_size,
231  int partitions) const {
232 
233  return GemmCoord(
234  (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
235  (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
236  partitions);
237  }
238 
241  dim3 get_grid_shape(GemmCoord tiled_shape) const {
242  return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
243  }
244 
245 
247  CUTLASS_DEVICE
249  return GemmCoord{
253  };
254  }
255 };
256 
258 
261 
265  GemmCoord problem_size,
266  GemmCoord tile_size,
267  int partitions) const {
268 
269  return GemmCoord(
270  (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
271  (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
272  partitions);
273  }
274 
277  dim3 get_grid_shape(GemmCoord tiled_shape) const {
278  return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k());
279  }
280 
281 
283  CUTLASS_DEVICE
285  return GemmCoord{
289  };
290  }
291 };
292 
294 
297 
301  BatchedGemmCoord problem_size,
302  BatchedGemmCoord tile_size) const {
303 
304  return BatchedGemmCoord(
305  1, // M is always 1
306  (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
307  (problem_size.k() + tile_size.k() - 1) / tile_size.k(),
308  (problem_size.batch() + tile_size.batch() - 1) / tile_size.batch());
309  }
310 
313  dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const {
314  return dim3(tiled_shape.n(), tiled_shape.batch(), tiled_shape.k());
315  }
316 
318  CUTLASS_DEVICE
320  return BatchedGemmCoord{
321  0, // M is always 1
325  };
326  }
327 
329  CUTLASS_DEVICE
330  int get_batch_tile_idx() const {
331  return RematerializeBlockIdxY();
332  }
333 
335  CUTLASS_DEVICE
336  int get_batch_idx() const {
338  }
339 };
340 
342 
343 } // namespace threadblock
344 } // namespace gemm
345 } // namespace cutlass
346 
int const kTile
Definition: gemm/threadblock/threadblock_swizzle.h:106
Definition: aligned_buffer.h:35
CUTLASS_DEVICE int RematerializeThreadIdxY()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:52
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:200
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:206
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:241
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int partitions) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:264
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
CUTLASS_DEVICE int RematerializeBlockDimX()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:82
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:123
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int split_k_slices) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:152
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:165
CUTLASS_HOST_DEVICE dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:313
CUTLASS_DEVICE int RematerializeThreadIdxX()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:46
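Threadblock swizzling function for GEMMs (horizontal variant: grid x spans the N tiles, grid y spans the M tiles, grid z spans the split-K slices).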
Definition: gemm/threadblock/threadblock_swizzle.h:145
Definition: include/cutlass/gemm/gemm.h:260
CUTLASS_HOST_DEVICE GemmHorizontalThreadblockSwizzle()
Definition: gemm/threadblock/threadblock_swizzle.h:148
CUTLASS_DEVICE int get_batch_tile_idx() const
Gets the batch tile index.
Definition: gemm/threadblock/threadblock_swizzle.h:330
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:277
CUTLASS_DEVICE int RematerializeThreadIdxZ()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:58
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: include/cutlass/gemm/gemm.h:322
CUTLASS_DEVICE int RematerializeBlockDimY()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:88
CUTLASS_HOST_DEVICE GemmIdentityThreadblockSwizzle()
Definition: gemm/threadblock/threadblock_swizzle.h:104
CUTLASS_DEVICE int get_batch_idx() const
Gets the absolute batch index.
Definition: gemm/threadblock/threadblock_swizzle.h:336
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int split_k_slices) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:110
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:314
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:129
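Threadblock swizzling function for split-K GEMMs (horizontal variant: grid x spans the N tiles, grid y spans the M tiles, grid z spans the partitions).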
Definition: gemm/threadblock/threadblock_swizzle.h:260
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, int batch_count, GemmCoord tile_size) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:187
CUTLASS_DEVICE int RematerializeBlockIdxY()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:70
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int partitions) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:228
CUTLASS_HOST_DEVICE BatchedGemmCoord get_tiled_shape(BatchedGemmCoord problem_size, BatchedGemmCoord tile_size) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:300
CUTLASS_DEVICE int RematerializeBlockDimZ()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:94
CUTLASS_DEVICE BatchedGemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:319
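Threadblock swizzling function for GEMMs (identity variant: grid x spans the M tiles, grid y spans the N tiles, grid z spans the split-K slices).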
Definition: gemm/threadblock/threadblock_swizzle.h:101
Threadblock swizzling function for batched GEMVs.
Definition: gemm/threadblock/threadblock_swizzle.h:296
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:306
CUTLASS_DEVICE int get_batch_idx() const
Gets the batch index.
Definition: gemm/threadblock/threadblock_swizzle.h:216
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
CUTLASS_DEVICE int RematerializeBlockIdxZ()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:76
Threadblock swizzling function for batched GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:183
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:284
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:171
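Threadblock swizzling function for split-K GEMMs (identity variant: grid x spans the M tiles, grid y spans the N tiles, grid z spans the partitions).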
Definition: gemm/threadblock/threadblock_swizzle.h:224
Basic include for CUTLASS.
CUTLASS_DEVICE int RematerializeBlockIdxX()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:64
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:248