31 #if !defined(__clang__)
52 typename OperatorShape,
53 typename OperatorFragment,
63 typename OperatorShape_,
64 typename OperatorFragment_
70 using OperatorShape = OperatorShape_;
71 using OperatorFragment = OperatorFragment_;
78 using Element =
typename cutlass::arch::WmmaToCutlassDataType<WmmaDataType>::Type;
84 using Policy = WmmaTensorOpPolicy<WarpShape, OperatorShape, Layout>;
88 Policy::kRowsPerIteration,
93 using Fragment = WmmaFragmentArray<OperatorFragment, Policy::OperatorCount::kColumn * Policy::kWmmaFragmentsPerAccess>;
104 4 * Policy::kElementsPerAccess
153 add_tile_offset(tile_offset);
161 for(
int n=0; n < Policy::OperatorCount::kColumn; n++) {
165 nvcuda::wmma::store_matrix_sync(
169 nvcuda::wmma::layout_t::mem_row_major
178 store_with_pointer_offset(frag, 0);
185 for(
int n=0; n < Policy::OperatorCount::kColumn; n++) {
189 nvcuda::wmma::load_matrix_sync(
193 nvcuda::wmma::layout_t::mem_row_major
202 load_with_pointer_offset(frag, 0);
214 #endif // !defined(__clang__)
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
CUTLASS_HOST_DEVICE TileIteratorWmmaTensorOp & add_pointer_offset(Index pointer_offset)
Adds a pointer offset.
Definition: tile_iterator_wmma_tensor_op.h:138
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
CUTLASS_HOST_DEVICE TileIteratorWmmaTensorOp & add_tile_offset(TensorCoord const &tile_offset)
Advances in units of whole tiles along the logical coordinate space of the tensor ...
Definition: tile_iterator_wmma_tensor_op.h:145
Definition: aligned_buffer.h:35
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE TileIteratorWmmaTensorOp()
Default constructor.
Definition: tile_iterator_wmma_tensor_op.h:124
CUTLASS_HOST_DEVICE Element * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:254
WmmaTensorOpPolicy< WarpShape, OperatorShape, Layout > Policy
Definition: tile_iterator_wmma_tensor_op.h:84
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
typename TensorRef::Index Index
Definition: tile_iterator_wmma_tensor_op.h:81
CUTLASS_HOST_DEVICE TileIteratorWmmaTensorOp & operator+=(TensorCoord const &tile_offset)
Definition: tile_iterator_wmma_tensor_op.h:152
CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)
Adds an offset to each pointer.
Definition: tensor_ref.h:326
typename OperatorFragment::element_type WmmaDataType
Definition: tile_iterator_wmma_tensor_op.h:77
CUTLASS_HOST_DEVICE Stride stride() const
Returns the layout object's stride vector.
Definition: tensor_ref.h:277
CUTLASS_HOST_DEVICE void load(Fragment &frag) const
Load.
Definition: tile_iterator_wmma_tensor_op.h:201
CUTLASS_HOST_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store.
Definition: tile_iterator_wmma_tensor_op.h:159
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:301
CUTLASS_HOST_DEVICE TileIteratorWmmaTensorOp(TensorRef const &ref, unsigned lane_id)
Constructor from TensorRef.
Definition: tile_iterator_wmma_tensor_op.h:130
WarpShape_ WarpShape
Definition: tile_iterator_wmma_tensor_op.h:69
CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const
Load.
Definition: tile_iterator_wmma_tensor_op.h:183
typename Layout::Index Index
Index type.
Definition: tensor_ref.h:165
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
typename TensorRef::LongIndex LongIndex
Definition: tile_iterator_wmma_tensor_op.h:82
WmmaFragmentArray< OperatorFragment, Policy::OperatorCount::kColumn *Policy::kWmmaFragmentsPerAccess > Fragment
This is the fragment size produced by one access of the iterator.
Definition: tile_iterator_wmma_tensor_op.h:93
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of TensorOp instructions, of which a row-oriented slice is visible per iteration.
Defines layout functions used by TensorRef and derived classes.
typename cutlass::arch::WmmaToCutlassDataType< WmmaDataType >::Type Element
Data type of element stored in nvcuda::wmma::fragment.
Definition: tile_iterator_wmma_tensor_op.h:78
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
CUTLASS_HOST_DEVICE TensorRef & add_pointer_offset(LongIndex offset_)
Adds an offset to each pointer.
Definition: tensor_ref.h:319
Template for reading and writing tiles of accumulators to shared memory.
Definition: tile_iterator_wmma_tensor_op.h:56
CUTLASS_HOST_DEVICE void store(Fragment const &frag)
Store.
Definition: tile_iterator_wmma_tensor_op.h:177
Basic include for CUTLASS.
Definition: matrix_coord.h:39
typename Layout::LongIndex LongIndex
Long index used for pointer offsets.
Definition: tensor_ref.h:168