37 #define CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES 1 53 typename MmaSimtPolicy
64 typename MmaSimtPolicy_
66 class TileIteratorSimt<WarpShape_, Operator_, Element_, layout::RowMajor, MmaSimtPolicy_> {
70 using Operator = Operator_;
71 using Element = Element_;
83 Policy::kRowsPerIteration,
89 typename Operator::ElementC,
90 Policy::kElementsPerIteration>;
94 typename Operator::ElementC,
95 Policy::kAccumulatorElementCount>;
98 static int const kIterations = Policy::kIterations;
103 4 * Policy::kElementsPerAccess>;
132 pointer_(reinterpret_cast<
AccessType *>(ref.data())),
133 layout_(ref.stride()[0] /
Policy::kElementsPerAccess) {
135 auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout();
136 MatrixCoord lane_offset = lane_layout.inverse(lane_id);
138 pointer_ += layout_(lane_offset);
144 pointer_ += pointer_offset / Policy::kElementsPerAccess;
152 pointer_ += layout_({
153 tile_offset.
row() * Shape::kRow,
154 (tile_offset.
column() * Shape::kColumn / Policy::kElementsPerAccess)
164 add_tile_offset(tile_offset);
172 #if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES 175 ScalarAccessType
const *scalarFragPtr =
reinterpret_cast<ScalarAccessType
const *
>(&frag);
176 ScalarAccessType *scalarPointer =
reinterpret_cast<ScalarAccessType *
>(pointer_);
179 for (
int n = 0; n < Policy::kAccessesPerIteration; ++n) {
181 for (
int s = 0; s < Policy::kElementsPerAccess; s++) {
182 scalarPointer[n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s] = scalarFragPtr[n * Policy::kElementsPerAccess + s];
189 for (
int n = 0; n < Policy::kAccessesPerIteration; ++n) {
190 pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn] = frag_ptr[n];
198 store_with_pointer_offset(frag, 0);
208 for (
int n = 0; n < Policy::kAccessesPerIteration; ++n) {
209 frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn];
216 load_with_pointer_offset(frag, 0);
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
Definition: aligned_buffer.h:35
Array< typename Operator::ElementC, Policy::kAccumulatorElementCount > AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: tile_iterator_simt.h:95
Aligned array type.
Definition: array.h:511
Array< typename Operator::ElementC, Policy::kElementsPerIteration > Fragment
This is the fragment size produced by one access of the iterator.
Definition: tile_iterator_simt.h:90
typename TensorRef::Index Index
Definition: tile_iterator_simt.h:76
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
typename TensorRef::LongIndex LongIndex
Definition: tile_iterator_simt.h:77
Definition: simt_policy.h:50
CUTLASS_HOST_DEVICE void load(Fragment &frag) const
Load.
Definition: tile_iterator_simt.h:215
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE void store(Fragment const &frag)
Store.
Definition: tile_iterator_simt.h:197
CUTLASS_HOST_DEVICE TileIteratorSimt & add_tile_offset(TensorCoord const &tile_offset)
Advances in units of whole tiles along the logical coordinate space of the tensor.
Definition: tile_iterator_simt.h:150
CUTLASS_HOST_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store.
Definition: tile_iterator_simt.h:171
CUTLASS_HOST_DEVICE TileIteratorSimt & add_pointer_offset(Index pointer_offset)
Adds a pointer offset.
Definition: tile_iterator_simt.h:143
CUTLASS_HOST_DEVICE TileIteratorSimt(TensorRef const &ref, unsigned lane_id)
Constructor from TensorRef.
Definition: tile_iterator_simt.h:128
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE TileIteratorSimt()
Default constructor.
Definition: tile_iterator_simt.h:124
typename Layout::Index Index
Index type.
Definition: tensor_ref.h:165
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines layout functions used by TensorRef and derived classes.
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
CUTLASS_HOST_DEVICE TileIteratorSimt & operator+=(TensorCoord const &tile_offset)
Definition: tile_iterator_simt.h:162
CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const
Load.
Definition: tile_iterator_simt.h:203
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of SimtOp instructions, of which a row-oriented slice is visible per iteration.
WarpShape_ WarpShape
Definition: tile_iterator_simt.h:69
Definition: matrix_coord.h:39
Template for reading and writing tiles of accumulators to shared memory.
Definition: tile_iterator_simt.h:55
typename Layout::LongIndex LongIndex
Long index used for pointer offsets.
Definition: tensor_ref.h:168