57 int ElementsPerAccess = 1
79 static_assert(!(Shape::kContiguous % kElementsPerAccess),
"");
81 static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * kElementsPerAccess)),
82 "Shape must be divisible thread count.");
93 "Shape must be divisible by number of iterations of each thread." 104 layout::PitchLinearShape<
114 layout::PitchLinearShape<
118 layout::PitchLinearShape<
131 thread_id / Detail::ShapeVec::kContiguous);
138 int ElementsPerAccess = 1
142 static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0,
143 "Contiguous shape must divide number of threads");
147 static int const kThreads = Threads;
148 static int const kElementsPerAccess = ElementsPerAccess;
159 return TensorCoord(thread_id * Iterations::kContiguous * kElementsPerAccess, 0);
166 int ElementsPerAccess = 1
171 "Strided shape must divide number of threads");
175 static int const kThreads = Threads;
176 static int const kElementsPerAccess = ElementsPerAccess;
180 Shape::kStrided / kThreads>;
190 return TensorCoord(0, thread_id * Iterations::kStrided);
202 typename WarpThreadArrangement_,
203 int ElementsPerAccess = 1
214 static int const kThreads = Threads;
217 static int const kElementsPerAccess = ElementsPerAccess;
229 static int const kWarpSize = WarpThreadArrangement::kCount;
232 static int const kWarpCount = kThreads / kWarpSize;
235 !(Shape::kContiguous % kElementsPerAccess),
236 "Shape must be divisible by vector length.");
246 ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous,
247 ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided
252 static int const kWarpsStrided =
253 (WarpAccessIterations::kStrided >= kWarpCount
255 : WarpAccessIterations::kStrided);
257 static int const kWarpsContiguous =
258 (kWarpCount > WarpAccessIterations::kStrided
259 ? kWarpCount / kWarpsStrided
264 kWarpsContiguous, kWarpsStrided
270 Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
271 Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
275 "Number of iterations must be non-zero");
280 Detail::WarpThreadArrangement::kStrided
287 int warp_id = (thread_id / Detail::kWarpSize);
288 int lane_id = (thread_id % Detail::kWarpSize);
296 Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
297 Detail::WarpThreadArrangement::kStrided * Iterations::kStrided
302 (warp_id % Detail::kWarpsContiguous),
303 (warp_id / Detail::kWarpsContiguous)
308 lane_id % Detail::WarpThreadArrangement::kContiguous,
309 lane_id / Detail::WarpThreadArrangement::kContiguous
314 warp_footprint * warp_offset + thread_offset_in_warp;
319 thread_offset_in_threadblock_tile_vec.
strided()
322 return thread_offset_in_threadblock_tile_base;
332 template <
typename ThreadMap_,
typename WarpThreadArrangement_>
341 using Shape =
typename ThreadMap::Shape;
344 static int const kThreads = ThreadMap::kThreads;
347 static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
358 static int const kWarpSize = WarpThreadArrangement::kCount;
361 static int const kWarpCount = kThreads / kWarpSize;
364 "Shape must be divisible by vector length.");
369 ThreadMap::Detail::kWarpsContiguous>;
375 ThreadMap::Iterations::kContiguous>;
377 static_assert(Iterations::kCount,
"Number of iterations must be non-zero");
383 Detail::WarpThreadArrangement::kStrided>;
391 int warp_id = (thread_id / Detail::kWarpSize);
392 int lane_id = (thread_id % Detail::kWarpSize);
401 Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
402 Detail::WarpThreadArrangement::kStrided * Iterations::kStrided};
407 (warp_id / Detail::WarpArrangement::kStrided),
408 (warp_id % Detail::WarpArrangement::kStrided)};
412 lane_id % Detail::WarpThreadArrangement::kContiguous,
413 lane_id / Detail::WarpThreadArrangement::kContiguous};
418 warp_footprint * warp_offset + thread_offset_in_warp;
424 thread_offset_in_threadblock_tile_vec.
strided()};
426 return thread_offset_in_threadblock_tile_base;
430 template <
typename ThreadMap_>
439 using Shape =
typename ThreadMap::Shape;
442 static int const kThreads = ThreadMap::kThreads;
445 static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
447 static_assert(kElementsPerAccess == 1 ,
"Simt transpose requires elements per access to be 1");
451 ThreadMap::Iterations::kContiguous>;
453 static_assert(Iterations::kCount,
"Number of iterations must be non-zero");
461 ThreadMap::Delta::kContiguous>;
470 TensorCoord coord = ThreadMap::initial_offset(thread_id);
488 typename WarpThreadArrangement_,
489 int ElementsPerAccess = 1
500 static int const kThreads = Threads;
503 static int const kElementsPerAccess = ElementsPerAccess;
515 static int const kWarpSize = WarpThreadArrangement::kCount;
518 static int const kWarpCount = kThreads / kWarpSize;
521 !(Shape::kContiguous % kElementsPerAccess),
522 "Shape must be divisible by vector length.");
532 ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous,
533 ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided
538 static int const kWarpsStrided =
539 (WarpAccessIterations::kStrided >= kWarpCount
540 ? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided));
542 static int const kWarpsContiguous =
543 (kWarpCount > WarpAccessIterations::kStrided ?
544 WarpAccessIterations::kContiguous / kWarpsStrided : 1);
548 kWarpsContiguous, kWarpsStrided
554 Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
555 Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
559 "Number of iterations must be non-zero");
564 Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided
571 int warp_id = (thread_id / Detail::kWarpSize);
572 int lane_id = (thread_id % Detail::kWarpSize);
580 Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
581 Detail::WarpThreadArrangement::kStrided
586 (warp_id % Detail::kWarpsContiguous),
587 (warp_id / Detail::kWarpsContiguous)
592 lane_id % Detail::WarpThreadArrangement::kContiguous,
593 lane_id / Detail::WarpThreadArrangement::kContiguous
598 warp_footprint * warp_offset + thread_offset_in_warp;
603 thread_offset_in_threadblock_tile_vec.
strided()
606 return thread_offset_in_threadblock_tile_base;
621 typename ThreadTileShape
643 static int const kThreads = Threads;
648 static_assert(!(kElementsPerAccess % 4) ,
"kElementsPerAccess, needs to be multiple of 4 (32bits)");
658 "Shape must be divisible thread count * accesses per thread.");
669 "Shape must be divisible by number of iterations of each thread." 680 layout::PitchLinearShape<
690 layout::PitchLinearShape<
694 layout::PitchLinearShape<
712 template <
typename ThreadMap_>
721 using Shape =
typename ThreadMap::Shape;
724 static int const kThreads = ThreadMap::kThreads;
727 static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
// Compile-time guard: this mapping needs more than one element per access.
// The diagnostic text states the asserted condition (> 1) so a failing build points at the real requirement.
730 static_assert(kElementsPerAccess > 1 ,
"Simt transpose requires elements per access to be greater than 1");
734 ThreadMap::Iterations::kContiguous>;
736 static_assert(Iterations::kCount,
"Number of iterations must be non-zero");
744 ThreadMap::Delta::kContiguous>;
753 TensorCoord coord = ThreadMap::initial_offset(thread_id);
static int const kCount
Definition: pitch_linear.h:46
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Defines a structure containing strides and a pointer to tensor data.
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Defines container classes and iterators for managing a statically sized vector of boolean predicates.
static int const kStrided
Definition: pitch_linear.h:45
static int const kContiguous
Definition: pitch_linear.h:44
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE Index const & contiguous() const
Returns the contiguous dimension.
Definition: pitch_linear.h:89
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE Index const & strided() const
Returns the strided dimension.
Definition: pitch_linear.h:97