48 namespace threadblock {
74 using Element = Element_;
76 static int const kAdvanceRank = AdvanceRank;
77 using ThreadMap = ThreadMap_;
78 static int const kAlignment = Alignment;
86 using Fragment = Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
89 "Advance rank may only be along the contiguous or strided dimensions.");
110 Index increment_strided_;
113 Index increment_advance_;
127 TensorCoord t = ThreadMap::initial_offset(thread_idx);
128 long int offset = t[0] * interleave + t[1] * ref.
stride()[0]/interleave;
129 pointer_ =
reinterpret_cast<uint8_t *
>(ref.
data() + offset);
131 stride_ = ref.
stride()[0] / interleave;
148 for (
int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
153 for (
int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
155 int idx = c + s * ThreadMap::Iterations::kContiguous;
156 frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided];
159 if (s + 1 < ThreadMap::Iterations::kStrided) {
160 byte_pointer += increment_strided_;
168 load_with_pointer_offset(
170 tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess +
171 tile_offset.strided() * Shape::kStrided * stride_
178 load_with_pointer_offset(frag, 0);
189 for (
int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
194 for (
int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
196 int idx = c + s * ThreadMap::Iterations::kContiguous;
197 access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided] = frag_ptr[idx];
200 if (s + 1 < ThreadMap::Iterations::kStrided) {
201 byte_pointer += increment_strided_;
209 store_with_pointer_offset(
211 tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_
218 store_with_pointer_offset(frag, 0);
224 pointer_ += increment_advance_;
231 pointer_ -= increment_advance_;
238 pointer_ += pointer_offset;
245 (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8;
246 add_pointer_offset(offset);
265 using Element = Element_;
267 static int const kAdvanceRank = AdvanceRank;
268 using ThreadMap = ThreadMap_;
269 static int const kAlignment = Alignment;
277 using Fragment = Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
283 (kAdvanceRank == 0 ? 1 : 0),
289 "Advance rank may only be along the row or column dimensions.");
305 iterator_({ref.
data(), ref.
stride()}, thread_idx, 4) {
312 iterator_.load_with_pointer_offset(frag, pointer_offset);
318 iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()});
324 iterator_.load_with_pointer_offset(frag, 0);
330 iterator_.store_with_pointer_offset(frag, pointer_offset);
336 iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()});
342 iterator_.store_with_pointer_offset(frag, 0);
362 iterator_.add_pointer_offset(pointer_offset);
368 iterator_.add_tile_offset({coord.column(), coord.row()});
387 using Element = Element_;
389 static int const kAdvanceRank = AdvanceRank;
390 using ThreadMap = ThreadMap_;
391 static int const kAlignment = Alignment;
399 using Fragment = Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
401 ThreadMap::kThreads, ThreadMap::ThreadAccessShape::kCount >;
408 (kAdvanceRank == 0 ? 0 : 1),
413 "Advance rank may only be along the row or column dimensions.");
429 iterator_({ref.
data(), ref.
stride()}, thread_idx, 4) {
436 iterator_.load_with_pointer_offset(frag, pointer_offset);
442 iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()});
448 iterator_.load_with_pointer_offset(frag, 0);
454 iterator_.store_with_pointer_offset(frag, pointer_offset);
460 iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()});
466 iterator_.store_with_pointer_offset(frag, 0);
486 iterator_.add_pointer_offset(pointer_offset);
492 iterator_.add_tile_offset({coord.row(), coord.column()});
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:355
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
static int const value
Definition: numeric_types.h:43
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE Element * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:254
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:249
Mapping function for pitch-linear memory.
Definition: pitch_linear.h:163
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:352
Aligned array type.
Definition: array.h:511
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:246
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
int64_t LongIndex
Long index type used for offsets.
Definition: pitch_linear.h:175
CUTLASS_HOST_DEVICE Stride stride() const
Returns the layout object's stride vector.
Definition: tensor_ref.h:277
Defines the size of an element in bits.
Definition: numeric_types.h:42
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int32_t Index
Index type used for coordinates.
Definition: pitch_linear.h:172
Templates implementing storing of tiles from pitch-linear rank=2 tensors.
Defines layout functions used by TensorRef and derived classes.
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
Basic include for CUTLASS.
Definition: matrix_coord.h:39