40 namespace threadblock {
113 int split_k_slices)
const {
116 (problem_size.
m() + tile_size.
m() - 1) / tile_size.
m(),
117 (problem_size.
n() + tile_size.
n() - 1) / tile_size.
n(),
124 return dim3(tiled_shape.
m() *
kTile, (tiled_shape.
n() + kTile - 1) / kTile, tiled_shape.
k());
135 (block_idx_x /
kTile),
136 (block_idx_y * kTile) + (block_idx_x %
kTile),
155 int split_k_slices)
const {
158 (problem_size.
m() + tile_size.
m() - 1) / tile_size.
m(),
159 (problem_size.
n() + tile_size.
n() - 1) / tile_size.
n(),
166 return dim3(tiled_shape.
n(), tiled_shape.
m(), tiled_shape.
k());
193 (problem_size.
m() + tile_size.
m() - 1) / tile_size.
m(),
194 (problem_size.
n() + tile_size.
n() - 1) / tile_size.
n(),
195 batch_count % (1 << 16));
201 return dim3(tiled_shape.
m(), tiled_shape.
n(), tiled_shape.
k());
231 int partitions)
const {
234 (problem_size.
m() + tile_size.
m() - 1) / tile_size.
m(),
235 (problem_size.
n() + tile_size.
n() - 1) / tile_size.
n(),
242 return dim3(tiled_shape.
m(), tiled_shape.
n(), tiled_shape.
k());
267 int partitions)
const {
270 (problem_size.
m() + tile_size.
m() - 1) / tile_size.
m(),
271 (problem_size.
n() + tile_size.
n() - 1) / tile_size.
n(),
278 return dim3(tiled_shape.
n(), tiled_shape.
m(), tiled_shape.
k());
306 (problem_size.
n() + tile_size.
n() - 1) / tile_size.
n(),
307 (problem_size.
k() + tile_size.
k() - 1) / tile_size.
k(),
314 return dim3(tiled_shape.
n(), tiled_shape.
batch(), tiled_shape.
k());
int const kTile
Definition: gemm/threadblock/threadblock_swizzle.h:106
Definition: aligned_buffer.h:35
CUTLASS_DEVICE int RematerializeThreadIdxY()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:52
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:200
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:206
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:241
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int partitions) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:264
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
CUTLASS_DEVICE int RematerializeBlockDimX()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:82
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:123
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int split_k_slices) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:152
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:165
CUTLASS_HOST_DEVICE dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:313
CUTLASS_DEVICE int RematerializeThreadIdxX()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:46
Threadblock swizzling function for GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:145
Definition: include/cutlass/gemm/gemm.h:260
CUTLASS_HOST_DEVICE GemmHorizontalThreadblockSwizzle()
Definition: gemm/threadblock/threadblock_swizzle.h:148
CUTLASS_DEVICE int get_batch_tile_idx() const
Gets the batch tile index.
Definition: gemm/threadblock/threadblock_swizzle.h:330
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:277
CUTLASS_DEVICE int RematerializeThreadIdxZ()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:58
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: include/cutlass/gemm/gemm.h:322
CUTLASS_DEVICE int RematerializeBlockDimY()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:88
CUTLASS_HOST_DEVICE GemmIdentityThreadblockSwizzle()
Definition: gemm/threadblock/threadblock_swizzle.h:104
CUTLASS_DEVICE int get_batch_idx() const
Gets the absolute batch index.
Definition: gemm/threadblock/threadblock_swizzle.h:336
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int split_k_slices) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:110
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:314
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:129
Threadblock swizzling function for split-K GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:260
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, int batch_count, GemmCoord tile_size) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:187
CUTLASS_DEVICE int RematerializeBlockIdxY()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:70
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int partitions) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:228
CUTLASS_HOST_DEVICE BatchedGemmCoord get_tiled_shape(BatchedGemmCoord problem_size, BatchedGemmCoord tile_size) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:300
CUTLASS_DEVICE int RematerializeBlockDimZ()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:94
CUTLASS_DEVICE BatchedGemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:319
Threadblock swizzling function for GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:101
Threadblock swizzling function for batched GEMVs.
Definition: gemm/threadblock/threadblock_swizzle.h:296
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:306
CUTLASS_DEVICE int get_batch_idx() const
Gets the batch index.
Definition: gemm/threadblock/threadblock_swizzle.h:216
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
CUTLASS_DEVICE int RematerializeBlockIdxZ()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:76
Threadblock swizzling function for batched GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:183
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:284
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:171
Threadblock swizzling function for split-K GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:224
Basic include for CUTLASS.
CUTLASS_DEVICE int RematerializeBlockIdxX()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:64
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:248