#if defined(CUTLASS_ARCH_WMMA_ENABLED)

/// Tile iterator that loads WMMA multiplicand operands held by a warp. The primary
/// template is declared here and specialized below for the A and B operands.
template <
  typename Shape_,    ///< size of the warp-level matrix tile (concept: MatrixShape)
  Operand Operand_,   ///< identity of the operand (A or B)
  typename Element_,  ///< data type of elements
  typename Layout_,   ///< memory layout of the operand
  int OpDelta_,       ///< interval between adjacent WMMA instructions
  int Threads,        ///< number of participating threads
  typename Policy_    ///< policy defining the underlying warp-level WMMA operator
>
class MmaTensorOpWmmaMultiplicandTileIterator;
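These iterators wrap the nvcuda::wmma fragment API. For orientation, a minimal standalone kernel written directly against that API looks like the sketch below (a sketch only, not part of this header; the 16x16x16 half/float tile shape and the kernel name are assumptions):

#include <mma.h>
using namespace nvcuda;

// One warp computes a single 16x16x16 tile of C = A * B + C.
__global__ void wmma_tile_kernel(half const *A, half const *B, float *C,
                                 int lda, int ldb, int ldc) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_A;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> frag_B;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> frag_C;

  wmma::fill_fragment(frag_C, 0.0f);        // zero the accumulator
  wmma::load_matrix_sync(frag_A, A, lda);   // collective load by the whole warp
  wmma::load_matrix_sync(frag_B, B, ldb);
  wmma::mma_sync(frag_C, frag_A, frag_B, frag_C);
  wmma::store_matrix_sync(C, frag_C, ldc, wmma::mem_row_major);
}

The iterators below perform the pointer arithmetic needed to issue these load/store calls for every WMMA-sized sub-tile of a larger warp-level tile.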
/// Specialization for the A operand: loads the warp's A multiplicand into nvcuda::wmma
/// fragments. This iterator is specialized for the 32-thread WMMA operation.
template <typename Shape_, typename Element_, typename Layout_, int OpDelta_, typename Policy_>
class MmaTensorOpWmmaMultiplicandTileIterator<
    Shape_, Operand::kA, Element_, Layout_,
    OpDelta_, 32, Policy_> {
public:

  using Shape = Shape_;                         ///< shape of the warp-level tile (concept: MatrixShape)
  static Operand const kOperand = Operand::kA;  ///< operand tag
  using Element = Element_;                     ///< data type of elements
  using Layout = Layout_;                       ///< memory layout of the operand
  static int const kOpDelta = OpDelta_;         ///< interval between adjacent WMMA instructions
  using Policy = Policy_;                       ///< policy defining the underlying WMMA operator

  /// TensorRef and associated index/coordinate types used to address the operand
  using TensorRef = TensorRef<Element, Layout>;
  using Index = typename TensorRef::Index;
  using LongIndex = typename TensorRef::LongIndex;
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Shape of a single underlying WMMA instruction for the A operand (M-by-K)
  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kM,
    Policy::Operator::Shape::kK
  >;

  /// Data type expected by the nvcuda::wmma API for this element type
  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  /// Number of WMMA instructions needed to cover the warp-level tile
  using Iterations = MatrixShape<
    Shape::kRow / WmmaShape::kRow,
    1
  >;

  /// Fragment object: one WMMA fragment per iteration
  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentA, Iterations::kCount>;

  static_assert(kOperand == Operand::kA,
    "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for A operands to warp-level Mma.");

  static_assert(platform::is_same<Layout, cutlass::layout::RowMajor>::value ||
                platform::is_same<Layout, cutlass::layout::ColumnMajor>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");

  static_assert(kOpDelta == 1,
    "Alternative arrangements not supported at present.");
private:

  /// Pointer to the start of the operand tile
  char const *pointer_;

  /// Byte offset advanced as the iterator moves
  Index byte_offset_;

  /// Stride in units of elements
  Index stride_;

  /// Layout object mapping logical coordinates to linear offsets
  Layout layout_;

public:

  /// Default constructor
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator() { }

  /// Constructor from a TensorRef and the calling thread's lane id
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): pointer_(reinterpret_cast<char const*>(ref.data())), byte_offset_(0),
     stride_(ref.stride(0)), layout_(ref.stride(0)) {
  }
  /// Adds a pointer offset, in units of elements
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
    byte_offset_ += offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator by a whole number of tiles
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
    Index elements_offset = layout_({tile_offset.row() * Shape::kRow, tile_offset.column() * WmmaShape::kColumn});
    byte_offset_ += elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator along the K dimension by one WMMA instruction shape
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &operator++() {
    Index elements_offset = layout_({0, WmmaShape::kColumn});
    byte_offset_ += elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Steps the iterator back along the K dimension by one WMMA instruction shape
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &operator--() {
    Index elements_offset = layout_({0, WmmaShape::kColumn});
    byte_offset_ -= elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }
  /// Loads a fragment from memory at the location pointed to by the iterator, with an
  /// additional offset in bytes
  CUTLASS_HOST_DEVICE
  void load_with_byte_offset(Fragment &frag, Index byte_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {

        // Byte offset of the (m, k)-th WMMA tile within the warp-level tile
        Index load_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;

        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType *>(pointer_ + byte_offset_ + load_byte_offset + byte_offset);

        // Warp-collective load of one WMMA fragment
        nvcuda::wmma::load_matrix_sync(frag[m], ptr, stride_);
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_byte_offset(frag, 0);
  }
  /// Stores a fragment to memory at the location pointed to by the iterator, with an
  /// additional offset in bytes
  CUTLASS_HOST_DEVICE
  void store_with_byte_offset(Fragment const &frag, Index byte_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {

        // Byte offset of the (m, k)-th WMMA tile within the warp-level tile
        Index store_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;

        WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(pointer_ + byte_offset_ + store_byte_offset + byte_offset);

        // Warp-collective store of one WMMA fragment
        nvcuda::wmma::store_matrix_sync(ptr, frag[m], stride_);
      }
    }
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_byte_offset(frag, 0);
  }

  /// Notifies the iterator of the current k-group index
  CUTLASS_HOST_DEVICE
  void set_kgroup_index(int k_group) {
  }
};
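A minimal usage sketch for the A-operand iterator (hypothetical names: WarpIteratorA stands for a concrete instantiation of this iterator, lane_id for the calling thread's lane within its warp): construct the iterator over the warp's tile, then alternately load a fragment and advance one WMMA step along K.

// Sketch only: loads a warp's A operand along the K dimension, one Fragment per k-group.
template <typename WarpIteratorA>
__device__ void load_A_along_k(
    typename WarpIteratorA::TensorRef ref_A,   // refers to the warp's A tile (e.g. in shared memory)
    typename WarpIteratorA::Fragment *frags,   // one fragment per k-group
    int k_groups,
    int lane_id) {

  WarpIteratorA iter_A(ref_A, lane_id);

  for (int k = 0; k < k_groups; ++k) {
    iter_A.load(frags[k]);   // issues one nvcuda::wmma::load_matrix_sync per entry of Iterations
    ++iter_A;                // advances by WmmaShape::kColumn along K
  }
}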
/// Specialization for the B operand: loads the warp's B multiplicand into nvcuda::wmma
/// fragments. This iterator is specialized for the 32-thread WMMA operation.
template <typename Shape_, typename Element_, typename Layout_, int OpDelta_, typename Policy_>
class MmaTensorOpWmmaMultiplicandTileIterator<
    Shape_, Operand::kB, Element_, Layout_,
    OpDelta_, 32, Policy_> {
public:

  using Shape = Shape_;                         ///< shape of the warp-level tile (concept: MatrixShape)
  static Operand const kOperand = Operand::kB;  ///< operand tag
  using Element = Element_;                     ///< data type of elements
  using Layout = Layout_;                       ///< memory layout of the operand
  static int const kOpDelta = OpDelta_;         ///< interval between adjacent WMMA instructions
  using Policy = Policy_;                       ///< policy defining the underlying WMMA operator

  /// TensorRef and associated index/coordinate types used to address the operand
  using TensorRef = TensorRef<Element, Layout>;
  using Index = typename TensorRef::Index;
  using LongIndex = typename TensorRef::LongIndex;
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Shape of a single underlying WMMA instruction for the B operand (K-by-N)
  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kK,
    Policy::Operator::Shape::kN
  >;

  /// Data type expected by the nvcuda::wmma API for this element type
  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  /// Number of WMMA instructions needed to cover the warp-level tile
  using Iterations = MatrixShape<
    1,
    Shape::kColumn / WmmaShape::kColumn
  >;

  /// Fragment object: one WMMA fragment per iteration
  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentB, Iterations::kCount>;

  static_assert(kOperand == Operand::kB,
    "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for B operands to warp-level Mma.");

  static_assert(platform::is_same<Layout, cutlass::layout::RowMajor>::value ||
                platform::is_same<Layout, cutlass::layout::ColumnMajor>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");

  static_assert(kOpDelta == 1,
    "Alternative arrangements not supported at present.");
private:

  /// Pointer to the start of the operand tile
  char const *pointer_;

  /// Byte offset advanced as the iterator moves
  Index byte_offset_;

  /// Stride in units of elements
  Index stride_;

  /// Layout object mapping logical coordinates to linear offsets
  Layout layout_;

public:

  /// Default constructor
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator() { }

  /// Constructor from a TensorRef and the calling thread's lane id
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): pointer_(reinterpret_cast<char const*>(ref.data())), byte_offset_(0),
     stride_(ref.stride(0)), layout_(ref.stride(0)) {
  }
  /// Adds a pointer offset, in units of elements
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
    byte_offset_ += offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator by a whole number of tiles
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
    Index elements_offset = layout_({tile_offset.row() * WmmaShape::kRow, tile_offset.column() * Shape::kColumn});
    byte_offset_ += elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator along the K dimension by one WMMA instruction shape
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &operator++() {
    Index elements_offset = layout_({WmmaShape::kRow, 0});
    byte_offset_ += elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Steps the iterator back along the K dimension by one WMMA instruction shape
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &operator--() {
    Index elements_offset = layout_({WmmaShape::kRow, 0});
    byte_offset_ -= elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }
  /// Loads a fragment from memory at the location pointed to by the iterator, with an
  /// additional offset in bytes
  CUTLASS_HOST_DEVICE
  void load_with_byte_offset(Fragment &frag, Index byte_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        // Byte offset of the (k, n)-th WMMA tile within the warp-level tile
        Index load_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;

        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType *>(pointer_ + byte_offset_ + load_byte_offset + byte_offset);

        // Warp-collective load of one WMMA fragment
        nvcuda::wmma::load_matrix_sync(frag[n], ptr, stride_);
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_byte_offset(frag, 0);
  }
  /// Stores a fragment to memory at the location pointed to by the iterator, with an
  /// additional offset in bytes
  CUTLASS_HOST_DEVICE
  void store_with_byte_offset(Fragment const &frag, Index byte_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        // Byte offset of the (k, n)-th WMMA tile within the warp-level tile
        Index store_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;

        WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(pointer_ + byte_offset_ + store_byte_offset + byte_offset);

        // Warp-collective store of one WMMA fragment
        nvcuda::wmma::store_matrix_sync(ptr, frag[n], stride_);
      }
    }
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_byte_offset(frag, 0);
  }

  /// Notifies the iterator of the current k-group index
  CUTLASS_HOST_DEVICE
  void set_kgroup_index(int k_group) {
  }
};
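The B-operand iterator mirrors the A-operand iterator, advancing along K by WMMA rows instead of columns. A hedged sketch of the warp-level mainloop that pairs the two iterators (WarpMma and its nested IteratorA/IteratorB/FragmentA/FragmentB/FragmentC names are assumptions about the enclosing warp-level operator, not part of this header):

// Sketch only: the canonical pattern in which a warp-level mma consumes these iterators.
template <typename WarpMma>
__device__ void warp_mainloop(
    typename WarpMma::IteratorA iter_A,       // built from the warp's A tile
    typename WarpMma::IteratorB iter_B,       // built from the warp's B tile
    typename WarpMma::FragmentC &accum,       // accumulator fragments
    int k_groups) {

  WarpMma warp_mma;
  typename WarpMma::FragmentA frag_A;
  typename WarpMma::FragmentB frag_B;

  for (int k = 0; k < k_groups; ++k) {
    iter_A.load(frag_A);                      // wmma::load_matrix_sync for each A sub-tile
    iter_B.load(frag_B);                      // wmma::load_matrix_sync for each B sub-tile
    ++iter_A;
    ++iter_B;
    warp_mma(accum, frag_A, frag_B, accum);   // wmma::mma_sync under the hood
  }
}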
/// Tile iterator that loads and stores the accumulator (C) operand of a warp-level
/// WMMA operation using nvcuda::wmma fragments.
template <
  typename Shape_,    ///< size of the warp-level accumulator tile (concept: MatrixShape)
  typename Element_,  ///< data type of accumulator elements
  typename Layout_,   ///< memory layout of the accumulator tile
  typename OpDelta_,  ///< interval between adjacent WMMA instructions (concept: MatrixShape)
  typename Policy_    ///< policy defining the underlying warp-level WMMA operator
>
class MmaTensorOpWmmaAccumulatorTileIterator;
/// Definition of the accumulator tile iterator. This iterator is specialized for the
/// 32-thread WMMA operation and accesses the C operand through nvcuda::wmma fragments.
template <typename Shape_, typename Element_, typename Layout_, typename OpDelta_, typename Policy_>
class MmaTensorOpWmmaAccumulatorTileIterator {
public:

  using Shape = Shape_;            ///< shape of the warp-level accumulator tile (concept: MatrixShape)
  using Element = Element_;        ///< data type of accumulator elements
  using Layout = Layout_;          ///< memory layout of the accumulator tile
  using OpDelta = OpDelta_;        ///< interval between adjacent WMMA instructions (concept: MatrixShape)
  static int const kThreads = 32;  ///< number of participating threads
  using Policy = Policy_;          ///< policy defining the underlying WMMA operator

  /// TensorRef and associated index/coordinate types used to address the accumulator tile
  using TensorRef = TensorRef<Element, Layout>;
  using Index = typename TensorRef::Index;
  using LongIndex = typename TensorRef::LongIndex;
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Shape of a single underlying WMMA instruction for the accumulator (M-by-N)
  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kM,
    Policy::Operator::Shape::kN
  >;

  /// Data type expected by the nvcuda::wmma API for this element type
  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  /// nvcuda::wmma memory layout tag corresponding to the CUTLASS layout
  static nvcuda::wmma::layout_t const WmmaLayout = cutlass::arch::CutlassToWmmaLayout<Layout>::value;

  /// Number of WMMA instructions needed to cover the warp-level accumulator tile
  using Iterations = MatrixShape<
    Shape::kRow / WmmaShape::kRow,
    Shape::kColumn / WmmaShape::kColumn
  >;

  /// Fragment object: one WMMA accumulator fragment per iteration
  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentC, Iterations::kCount>;

  static_assert(platform::is_same<Layout, cutlass::layout::RowMajor>::value ||
                platform::is_same<Layout, cutlass::layout::ColumnMajor>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");
private:

  /// TensorRef addressing the accumulator tile; advanced as the iterator moves
  TensorRef ref_;

public:

  /// Default constructor
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator() { }

  /// Constructor from a TensorRef and the calling thread's lane id
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator(
    TensorRef const &ref,
    int lane_id
  ): ref_(ref) {
  }
  /// Adds a pointer offset, in units of elements
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances the iterator by a whole number of accumulator tiles
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
    ref_.add_coord_offset({tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn});
    return *this;
  }
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &operator++() {
    // ...
    return *this;
  }

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &operator--() {
    // ...
    return *this;
  }
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }
  /// Loads the accumulator fragments from memory at the location referenced by the
  /// iterator, with an additional offset in units of elements
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType *>(
            ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset);

        // Warp-collective load of one WMMA accumulator fragment
        nvcuda::wmma::load_matrix_sync(frag[m * Iterations::kColumn + n], ptr, ref_.stride()[0], WmmaLayout);
      }
    }
  }

  /// Loads the accumulator fragments from memory at the location referenced by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }
  /// Stores the accumulator fragments to memory at the location referenced by the
  /// iterator, with an additional offset in units of elements
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(
            ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset);

        // Warp-collective store of one WMMA accumulator fragment
        nvcuda::wmma::store_matrix_sync(ptr, frag[m * Iterations::kColumn + n], ref_.stride()[0], WmmaLayout);
      }
    }
  }

  /// Stores the accumulator fragments to memory at the location referenced by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_pointer_offset(frag, 0);
  }

  /// Notifies the iterator of the current k-group index
  CUTLASS_HOST_DEVICE
  void set_kgroup_index(int k_group) {
  }
};
#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
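After the mainloop, the accumulator iterator writes each warp's results back through nvcuda::wmma::store_matrix_sync. A hedged sketch (WarpIteratorC and the function name are assumptions; WarpIteratorC stands for a concrete MmaTensorOpWmmaAccumulatorTileIterator instantiation):

// Sketch only: writes a warp's accumulator tile to memory through the accumulator iterator.
template <typename WarpIteratorC>
__device__ void store_accumulators(
    typename WarpIteratorC::TensorRef ref_C,        // destination tile (e.g. global or shared memory)
    typename WarpIteratorC::Fragment const &accum,  // accumulator fragments produced by the mainloop
    int lane_id) {

  WarpIteratorC iter_C(ref_C, lane_id);
  iter_C.store(accum);   // one nvcuda::wmma::store_matrix_sync per (m, n) entry of Iterations
}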