CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
output_tile_thread_map.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/fast_math.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////
/// Tuple defining point in output tile
template <
  int Column,
  int Row,
  int Group,
  int Cluster,
  int Tile
>
struct OutputTileShape {
  static int const kColumn = Column;
  static int const kRow = Row;
  static int const kGroup = Group;
  static int const kCluster = Cluster;
  static int const kTile = Tile;

  static int const kCount = kColumn * kRow * kGroup * kCluster * kTile;
};
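
// Illustrative example (editor's sketch, not part of the original header):
// an output tile spanning 64 columns, 8 rows, 2 groups, and 2 clusters,
// visited once. kCount is simply the product of all five extents.
using ExampleOutputTileShape = OutputTileShape<64, 8, 2, 2, 1>;
static_assert(ExampleOutputTileShape::kCount == 2048,
              "kCount == Column * Row * Group * Cluster * Tile");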

////////////////////////////////////////////////////////////////////////////////

/// Thread map for the output tile; wraps a conventional pitch-linear thread map
template <
  typename ThreadMap_,
  typename Shape_,
  typename Iterations_,
  typename Delta_,
  typename Count_
>
struct OutputTileThreadMap : public ThreadMap_ {

  /// Conventional thread map (concept: ThreadMap)
  using ThreadMap = ThreadMap_;

  /// Number of threads participating in the operation
  static int const kThreads = ThreadMap::kThreads;

  /// Number of scalar elements per access
  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

  /// Shape of the tile
  using Shape = Shape_;

  /// Iterations performed by each thread
  using Iterations = Iterations_;

  /// Delta between accesses
  using Delta = Delta_;

  /// Number of iterator iterations
  using Count = Count_;

  /// Initial offset function
  CUTLASS_HOST_DEVICE
  static MatrixCoord initial_offset(int thread_idx) {

    using Index = typename layout::PitchLinearCoord::Index;

    // Offset of this thread within the wrapped pitch-linear thread map
    layout::PitchLinearCoord coord = ThreadMap::initial_offset(thread_idx);

    // Decompose the strided coordinate into (cluster, group, row)
    Index cluster = coord.strided() / (Shape::kGroup * Shape::kRow);
    Index cluster_residual = coord.strided() % (Shape::kGroup * Shape::kRow);

    Index group = cluster_residual / (Shape::kRow);
    Index row = cluster_residual % (Shape::kRow);

    // Scale each level by its iteration count to obtain the output-space row
    return MatrixCoord{
      row + group * Shape::kRow * Count::kRow
        + cluster * Shape::kGroup * Count::kGroup * Shape::kRow * Count::kRow,
      coord.contiguous()
    };
  }
};
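
// Worked example (illustrative, values assumed): suppose Shape::kRow = 8,
// Shape::kGroup = 2, Count::kRow = 4, and Count::kGroup = 1. A thread whose
// pitch-linear strided coordinate is 19 decomposes as
//
//   cluster          = 19 / (2 * 8) = 1
//   cluster_residual = 19 % (2 * 8) = 3
//   group            =  3 / 8       = 0
//   row              =  3 % 8       = 3
//
// so its initial output row is 3 + 0 * (8 * 4) + 1 * (2 * 1 * 8 * 4) = 67,
// while the contiguous coordinate passes through unchanged as the column.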

////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// RowArrangement determines how one or more warps cover a region of
/// consecutive rows.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize,
  bool Is2dTile
>
struct RowArrangement;
/// RowArrangement in which each warp's access is a 1D tiled arrangement.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize
>
struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {
  static int const kWarpSize = 32;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  static int const kIterationsRow = 1;
  static int const kDeltaRow = 1;
  static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
  static int const kDeltaColumn = kWarpSize * kElementsPerAccess;

  static int const kAccessWidth = kWarpSize;
  static int const kAccessRows = 1;
  static int const kWarpPartitionsRow = 1;
  static int const kWarpPartitionsColumn = WarpsRemaining;
};

/// RowArrangement in which each warp's access is a 2D tiled arrangement.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize
>
struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {

  static int const kMemoryAccessSize = 128;   // target access size in bytes
  static int const kWarpSize = 32;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  struct Detail {
    static int const kShapeRow = Shape::kRow / WarpsRemaining;
    static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;

    static int const kTargetMemoryAccessWidth =
      kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);

    static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
  };

  static int const kAccessWidth =
    (Detail::kTargetAccessRows > Detail::kShapeRow ?
      kWarpSize / Detail::kShapeRow
      : const_min(
          Detail::kShapeWidth,
          const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8))
        ));

  static int const kAccessRows =
    (Detail::kTargetAccessRows > Detail::kShapeRow ?
      Detail::kShapeRow
      : const_min(Shape::kRow, kWarpSize / kAccessWidth));

  static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
  static int const kDeltaRow = kAccessRows;

  static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
  static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;

  static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access");
  static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" );
  static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" );

  static int const kWarpPartitionsRow = 1;
  static int const kWarpPartitionsColumn = 1;
};
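
// Illustrative example (editor's sketch): four warps covering an 8-row by
// 128-column region with 8-element accesses to 16-bit data. The target memory
// access width is 8 lanes (128 B line / 16 B access), suggesting 4 access
// rows, but each warp owns only kShapeRow = 2 rows; the warp is therefore
// reshaped to 16 lanes across by 2 rows deep and covers its region in a
// single iteration.
using ExampleRowArrangement =
    RowArrangement<OutputTileShape<128, 8, 1, 1, 1>, 4, 8, 16, true>;
static_assert(ExampleRowArrangement::kAccessWidth == 16, "lanes per row");
static_assert(ExampleRowArrangement::kAccessRows == 2, "rows per warp access");
static_assert(ExampleRowArrangement::kIterationsRow == 1 &&
              ExampleRowArrangement::kIterationsColumn == 1,
              "a single access covers the warp's region");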

} // namespace detail

////////////////////////////////////////////////////////////////////////////////

/// Template metaprogram for partitioning a 4D space across warps to achieve
/// several performance objectives:
///
///   - coalesced memory accesses in units of 128 Byte lines
///   - minimal address arithmetic
///   - minimal predicate calculations
///
template <
  typename Shape_,
  typename Count_,
  int Threads,
  int ElementsPerAccess,
  int ElementSize
>
struct OutputTileOptimalThreadMap {

  using Shape = Shape_;
  using Count = Count_;

  static int const kWarpSize = 32;
  static int const kThreads = Threads;
  static int const kWarpCount = kThreads / kWarpSize;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  //
  // Metaprogram computation
  //

  struct Detail {

    // Clusters
    static int const kIterationsCluster =
      ((Shape::kCluster > kWarpCount) ?
        Shape::kCluster / kWarpCount
        : 1);

    static int const kDeltaCluster =
      ((Shape::kCluster > kWarpCount) ?
        Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster
        : 1);

    static int const kCompactedDeltaCluster =
      ((Shape::kCluster > kWarpCount) ?
        Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster
        : 1);

    static int const kWarpPartitionsCluster =
      ((Shape::kCluster > kWarpCount) ?
        kWarpCount
        : kWarpCount / Shape::kCluster);

    static int const kWarpsRemainingForGroups =
      ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);

    // Groups
    static int const kIterationsGroup =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        Shape::kGroup / kWarpsRemainingForGroups
        : 1);

    static int const kDeltaGroup =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup
        : 1);

    static int const kCompactedDeltaGroup =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        Shape::kRow * Shape::kGroup / kIterationsGroup
        : 1);

    static int const kWarpPartitionsGroup =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        1
        : kWarpsRemainingForGroups / Shape::kGroup);

    static int const kWarpsRemainingForRows =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        1
        : kWarpsRemainingForGroups / Shape::kGroup);

    // Rows
    using RowArrangement = detail::RowArrangement<
      Shape,
      kWarpsRemainingForRows,
      kElementsPerAccess,
      kElementSize,
      (Shape::kRow > kWarpsRemainingForRows)
    >;

    // Warp partitions
    using WarpPartitions = OutputTileShape<
      RowArrangement::kWarpPartitionsColumn,
      RowArrangement::kWarpPartitionsRow,
      kWarpPartitionsGroup,
      kWarpPartitionsCluster,
      1>;

    static int const kAccessWidth = RowArrangement::kAccessWidth;
    static int const kAccessRows = RowArrangement::kAccessRows;
  };
316 
317  //
318  // Output
319  //
320 
321  using Iterations = OutputTileShape<
322  Detail::RowArrangement::kIterationsColumn,
323  Detail::RowArrangement::kIterationsRow,
324  Detail::kIterationsGroup,
325  Detail::kIterationsCluster,
326  1>;
327 
328  using Delta = OutputTileShape<
329  Detail::RowArrangement::kDeltaColumn,
330  Detail::RowArrangement::kDeltaRow,
331  Detail::kDeltaGroup,
332  Detail::kDeltaCluster,
333  1>;
334 
337  static MatrixCoord initial_offset(int thread_idx) {
338 
339  int warp_idx = thread_idx / kWarpSize;
340  int lane_idx = thread_idx % kWarpSize;
341 
342  // Compute warp location
343  int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
344  int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
345 
346  int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
347  int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
348 
349  int row_idx = residual_group / Detail::WarpPartitions::kRow;
350  int col_idx = residual_group % Detail::WarpPartitions::kRow;
351 
352  // Compute per-lane offset
353  int lane_row_offset = lane_idx / Detail::kAccessWidth;
354  int lane_col_offset = lane_idx % Detail::kAccessWidth;
355 
356  // Compute coordinate in output space
357  int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;
358  int group_offset = group_idx * Shape::kRow * Count::kRow;
359  int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
360  int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
361 
362  return MatrixCoord(
363  cluster_offset + group_offset + row_offset + lane_row_offset,
364  (column_offset + lane_col_offset) * kElementsPerAccess
365  );
366  }

  /// Compacted thread map in which the 4D region is contiguous
  struct CompactedThreadMap {

    using Shape = Shape_;

    using Iterations = OutputTileShape<
      Detail::RowArrangement::kIterationsColumn,
      Detail::RowArrangement::kIterationsRow,
      Detail::kIterationsGroup,
      Detail::kIterationsCluster,
      1>;

    using Delta = OutputTileShape<
      Detail::RowArrangement::kDeltaColumn,
      Detail::RowArrangement::kDeltaRow,
      Detail::kCompactedDeltaGroup,
      Detail::kCompactedDeltaCluster,
      1>;

    /// Number of elements within each vector access
    static int const kElementsPerAccess = ElementsPerAccess;

    /// Number of threads
    static int const kThreads = Threads;
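
    // Note: relative to the enclosing OutputTileOptimalThreadMap, this map
    // substitutes kCompactedDeltaGroup / kCompactedDeltaCluster and drops the
    // Count scaling from initial_offset below, matching a layout in which the
    // 4D region is stored contiguously (e.g. staged through an intermediate
    // buffer) rather than spread across the full output tile.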

    /// Function to compute each thread's initial offset
    CUTLASS_HOST_DEVICE
    static MatrixCoord initial_offset(int thread_idx) {

      int warp_idx = thread_idx / kWarpSize;
      int lane_idx = thread_idx % kWarpSize;

      // Compute warp location
      int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
      int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;

      int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
      int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;

      int row_idx = residual_group / Detail::WarpPartitions::kRow;
      int col_idx = residual_group % Detail::WarpPartitions::kRow;

      // Compute per-lane offset
      int lane_row_offset = lane_idx / Detail::kAccessWidth;
      int lane_col_offset = lane_idx % Detail::kAccessWidth;

      // Compute coordinate in output space
      int cluster_offset = cluster_idx * Shape::kRow * Shape::kGroup;
      int group_offset = group_idx * Shape::kRow;
      int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
      int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;

      MatrixCoord coord(
        cluster_offset + group_offset + row_offset + lane_row_offset,
        (column_offset + lane_col_offset) * kElementsPerAccess
      );

      return coord;
    }
  };
};
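
// Illustrative example (editor's sketch, parameter values assumed): a thread
// map for 256 threads (8 warps) covering a 128-column x 8-row x 2-group tile
// with 8-element accesses to 16-bit data. Four warps remain per group, so each
// warp covers two rows, and every thread issues exactly one access per tile.
using ExampleOptimalThreadMap = OutputTileOptimalThreadMap<
    OutputTileShape<128, 8, 2, 1, 1>,  // Shape: Column x Row x Group x Cluster x Tile
    OutputTileShape<1, 8, 2, 1, 8>,    // Count (assumed iteration counts)
    256,                               // Threads
    8,                                 // ElementsPerAccess
    16>;                               // ElementSize (bits)
static_assert(ExampleOptimalThreadMap::Iterations::kCount == 1,
              "256 threads x 8 elements span the 2048-element tile in one pass");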

////////////////////////////////////////////////////////////////////////////////

/// Thread map for interleaved output tiles in which iterations follow the
/// shape of the MMA operation count
template <typename WarpCount_, typename MmaCount_, int Threads,
          int ElementsPerAccess, int ElementSize>
struct InterleavedOutputTileThreadMap {

  using WarpCount = WarpCount_;
  using MmaCount = MmaCount_;

  static int const kWarpSize = 32;
  static int const kThreads = Threads;
  static int const kWarpCount = kThreads / kWarpSize;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  //
  // Metaprogram computation
  //

  struct Detail {};

  //
  // Output
  //

  using Iterations = MmaCount;

  using Delta = layout::PitchLinearShape<kWarpSize * kElementsPerAccess, 1>;

  /// Initial offset function
  CUTLASS_HOST_DEVICE
  static layout::PitchLinearCoord initial_offset(int thread_idx) {
    int warp_idx = thread_idx / kWarpSize;
    int lane_idx = thread_idx % kWarpSize;

    // Compute warp location
    layout::PitchLinearCoord warp_footprint{
        Delta::kContiguous * Iterations::kContiguous,
        Delta::kStrided * Iterations::kStrided};

    layout::PitchLinearCoord warp_offset{warp_idx % WarpCount::kContiguous,
                                         warp_idx / WarpCount::kContiguous};

    // Compute per-lane offset
    layout::PitchLinearCoord thread_offset_in_warp{
        lane_idx * kElementsPerAccess, 0};

    layout::PitchLinearCoord thread_offset_in_threadblock_tile =
        warp_footprint * warp_offset + thread_offset_in_warp;

    return thread_offset_in_threadblock_tile;
  }
};
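
// Worked example (illustrative; assumes the Delta definition above): with
// WarpCount = {2, 4} (8 warps, 256 threads), Iterations = MmaCount = {4, 2},
// and kElementsPerAccess = 4, Delta is {32 * 4, 1} = {128, 1} and the warp
// footprint is {128 * 4, 1 * 2} = {512, 2}. Thread 70 (warp 2, lane 6) gets
//
//   warp_offset           = {2 % 2, 2 / 2} = {0, 1}
//   thread_offset_in_warp = {6 * 4, 0}     = {24, 0}
//   initial offset        = {0 * 512 + 24, 1 * 2 + 0} = {24, 2}.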

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass