45 namespace threadblock {
59 static int const kRow = Row;
64 static int const kCount = kColumn * kRow * kGroup * kCluster *
kTile;
82 static int const kThreads = ThreadMap::kThreads;
85 static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
107 Index cluster = coord.
strided() / (Shape::kGroup * Shape::kRow);
108 Index cluster_residual = coord.
strided() % (Shape::kGroup * Shape::kRow);
110 Index group = cluster_residual / (Shape::kRow);
111 Index row = cluster_residual % (Shape::kRow);
114 row + group * Shape::kRow * Count::kRow
115 + cluster * Shape::kGroup * Count::kGroup * Shape::kRow * Count::kRow,
129 int ElementsPerAccess,
139 int ElementsPerAccess,
142 struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {
143 static int const kWarpSize = 32;
144 static int const kElementsPerAccess = ElementsPerAccess;
145 static int const kElementSize = ElementSize;
147 static int const kIterationsRow = 1;
148 static int const kDeltaRow = 1;
149 static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
150 static int const kDeltaColumn = kWarpSize * kElementsPerAccess;
152 static int const kAccessWidth = kWarpSize;
153 static int const kAccessRows = 1;
154 static int const kWarpPartitionsRow = 1;
155 static int const kWarpPartitionsColumn = WarpsRemaining;
162 int ElementsPerAccess,
165 struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {
167 static int const kMemoryAccessSize = 128;
168 static int const kWarpSize = 32;
170 static int const kElementsPerAccess = ElementsPerAccess;
171 static int const kElementSize = ElementSize;
174 static int const kShapeRow = Shape::kRow / WarpsRemaining;
175 static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;
177 static int const kTargetMemoryAccessWidth =
178 kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);
180 static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
183 static int const kAccessWidth =
184 (Detail::kTargetAccessRows > Detail::kShapeRow ?
185 kWarpSize / Detail::kShapeRow
188 const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8))
191 static int const kAccessRows =
192 (Detail::kTargetAccessRows > Detail::kShapeRow ?
194 :
const_min(Shape::kRow, kWarpSize / kAccessWidth));
196 static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
197 static int const kDeltaRow = kAccessRows;
199 static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
200 static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;
202 static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn,
"Accessing too many elements per access");
203 static_assert( kIterationsColumn > 0,
"Iteration Count Column must be > 0" );
204 static_assert( kIterationsRow > 0,
"Iteration Count Row must be > 0" );
206 static int const kWarpPartitionsRow = 1;
207 static int const kWarpPartitionsColumn = 1;
225 int ElementsPerAccess,
233 static int const kWarpSize = 32;
234 static int const kThreads = Threads;
235 static int const kWarpCount = kThreads / kWarpSize;
237 static int const kElementsPerAccess = ElementsPerAccess;
238 static int const kElementSize = ElementSize;
247 static int const kIterationsCluster =
248 ((Shape::kCluster > kWarpCount) ?
249 Shape::kCluster / kWarpCount
252 static int const kDeltaCluster =
253 ((Shape::kCluster > kWarpCount) ?
254 Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster
257 static int const kCompactedDeltaCluster =
258 ((Shape::kCluster > kWarpCount) ?
259 Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster
262 static int const kWarpPartitionsCluster =
263 ((Shape::kCluster > kWarpCount) ?
265 : kWarpCount / Shape::kCluster);
267 static int const kWarpsRemainingForGroups =
268 ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);
271 static int const kIterationsGroup =
272 ((Shape::kGroup > kWarpsRemainingForGroups) ?
273 Shape::kGroup / kWarpsRemainingForGroups
276 static int const kDeltaGroup =
277 ((Shape::kGroup > kWarpsRemainingForGroups) ?
278 Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup
281 static int const kCompactedDeltaGroup =
282 ((Shape::kGroup > kWarpsRemainingForGroups) ?
283 Shape::kRow * Shape::kGroup / kIterationsGroup
286 static int const kWarpPartitionsGroup =
287 ((Shape::kGroup > kWarpsRemainingForGroups) ?
289 : kWarpsRemainingForGroups / Shape::kGroup);
291 static int const kWarpsRemainingForRows =
292 ((Shape::kGroup > kWarpsRemainingForGroups) ?
294 : kWarpsRemainingForGroups / Shape::kGroup);
299 kWarpsRemainingForRows,
302 (Shape::kRow > kWarpsRemainingForRows)
307 RowArrangement::kWarpPartitionsColumn,
308 RowArrangement::kWarpPartitionsRow,
309 kWarpPartitionsGroup,
310 kWarpPartitionsCluster,
313 static int const kAccessWidth = RowArrangement::kAccessWidth;
314 static int const kAccessRows = RowArrangement::kAccessRows;
322 Detail::RowArrangement::kIterationsColumn,
323 Detail::RowArrangement::kIterationsRow,
324 Detail::kIterationsGroup,
325 Detail::kIterationsCluster,
329 Detail::RowArrangement::kDeltaColumn,
330 Detail::RowArrangement::kDeltaRow,
332 Detail::kDeltaCluster,
339 int warp_idx = thread_idx / kWarpSize;
340 int lane_idx = thread_idx % kWarpSize;
343 int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
344 int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
346 int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
347 int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
349 int row_idx = residual_group / Detail::WarpPartitions::kRow;
350 int col_idx = residual_group % Detail::WarpPartitions::kRow;
353 int lane_row_offset = lane_idx / Detail::kAccessWidth;
354 int lane_col_offset = lane_idx % Detail::kAccessWidth;
357 int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;
358 int group_offset = group_idx * Shape::kRow * Count::kRow;
359 int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
360 int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
363 cluster_offset + group_offset + row_offset + lane_row_offset,
364 (column_offset + lane_col_offset) * kElementsPerAccess
375 Detail::RowArrangement::kIterationsColumn,
376 Detail::RowArrangement::kIterationsRow,
377 Detail::kIterationsGroup,
378 Detail::kIterationsCluster,
382 Detail::RowArrangement::kDeltaColumn,
383 Detail::RowArrangement::kDeltaRow,
384 Detail::kCompactedDeltaGroup,
385 Detail::kCompactedDeltaCluster,
389 static int const kElementsPerAccess = ElementsPerAccess;
392 static int const kThreads = Threads;
398 int warp_idx = thread_idx / kWarpSize;
399 int lane_idx = thread_idx % kWarpSize;
402 int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
403 int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
405 int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
406 int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
408 int row_idx = residual_group / Detail::WarpPartitions::kRow;
409 int col_idx = residual_group % Detail::WarpPartitions::kRow;
412 int lane_row_offset = lane_idx / Detail::kAccessWidth;
413 int lane_col_offset = lane_idx % Detail::kAccessWidth;
416 int cluster_offset = cluster_idx * Shape::kRow * Shape::kGroup;
417 int group_offset = group_idx * Shape::kRow;
418 int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
419 int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
422 cluster_offset + group_offset + row_offset + lane_row_offset,
423 (column_offset + lane_col_offset) * kElementsPerAccess
440 template <
typename WarpCount_,
typename MmaCount_,
int Threads,
441 int ElementsPerAccess,
int ElementSize>
446 static int const kWarpSize = 32;
447 static int const kThreads = Threads;
448 static int const kWarpCount = kThreads / kWarpSize;
450 static int const kElementsPerAccess = ElementsPerAccess;
451 static int const kElementSize = ElementSize;
470 int warp_idx = thread_idx / kWarpSize;
471 int lane_idx = thread_idx % kWarpSize;
475 Delta::kContiguous * Iterations::kContiguous,
476 Delta::kStrided * Iterations::kStrided};
479 warp_idx / WarpCount::kContiguous};
483 lane_idx * kElementsPerAccess, 0};
486 warp_footprint * warp_offset + thread_offset_in_warp;
488 return thread_offset_in_threadblock_tile;
int Index
Integer-valued index.
Definition: pitch_linear.h:56
ThreadMap_ ThreadMap
Conventional thread map (concept: ThreadMap)
Definition: output_tile_thread_map.h:79
Definition: output_tile_thread_map.h:228
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
Count_ Count
Definition: output_tile_thread_map.h:231
static int const kGroup
Definition: output_tile_thread_map.h:60
Tuple defining point in output tile.
Definition: output_tile_thread_map.h:57
WarpCount_ WarpCount
Definition: output_tile_thread_map.h:443
Iterations_ Iterations
Iterations performed by each thread.
Definition: output_tile_thread_map.h:91
static int const kColumn
Definition: output_tile_thread_map.h:58
RowArrangement determines how one or more warps cover a region of consecutive rows.
Definition: output_tile_thread_map.h:133
Definition: output_tile_thread_map.h:442
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Compacted thread map in which the 4D region is contiguous.
Definition: output_tile_thread_map.h:369
Count_ Count
Number of iterator iterations.
Definition: output_tile_thread_map.h:97
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Function to compute each thread's initial offset.
Definition: output_tile_thread_map.h:396
Defines a Shape template for matrix tiles.
Shape_ Shape
Definition: output_tile_thread_map.h:230
static CUTLASS_HOST_DEVICE layout::PitchLinearCoord initial_offset(int thread_idx)
Initial offset function.
Definition: output_tile_thread_map.h:469
detail::RowArrangement< Shape, kWarpsRemainingForRows, kElementsPerAccess, kElementSize,(Shape::kRow > kWarpsRemainingForRows) > RowArrangement
Definition: output_tile_thread_map.h:303
MmaCount Iterations
Definition: output_tile_thread_map.h:463
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Index const & contiguous() const
Returns the contiguous dimension.
Definition: pitch_linear.h:89
Shape_ Shape
Definition: output_tile_thread_map.h:372
Delta_ Delta
Delta between accesses.
Definition: output_tile_thread_map.h:94
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Initial offset function.
Definition: output_tile_thread_map.h:101
Definition: output_tile_thread_map.h:457
static int const kRow
Definition: output_tile_thread_map.h:59
Defines layout functions used by TensorRef and derived classes.
Definition: output_tile_thread_map.h:76
Shape_ Shape
Shape of the tile.
Definition: output_tile_thread_map.h:88
static int const kTile
Definition: output_tile_thread_map.h:62
static int const kCount
Definition: output_tile_thread_map.h:64
MmaCount_ MmaCount
Definition: output_tile_thread_map.h:444
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Initial offset function.
Definition: output_tile_thread_map.h:337
CUTLASS_HOST_DEVICE constexpr int const_min(int a, int b)
Definition: fast_math.h:219
Basic include for CUTLASS.
Definition: matrix_coord.h:39
Definition: output_tile_thread_map.h:244
static int const kCluster
Definition: output_tile_thread_map.h:61
CUTLASS_HOST_DEVICE Index const & strided() const
Returns the strided dimension of the coordinate.
Definition: pitch_linear.h:97