40 #if !defined(__clang__) 58 typename OperatorShape,
59 typename OperatorElementC,
60 typename OperatorFragmentC,
70 typename OperatorShape_,
71 typename OperatorElementC_,
72 typename OperatorFragmentC_
78 using OperatorShape = OperatorShape_;
79 using OperatorElementC = OperatorElementC_;
80 using OperatorFragmentC = OperatorFragmentC_;
83 using Policy = WmmaTensorOpPolicy<WarpShape, OperatorShape, Layout>;
86 using Fragment = WmmaFragmentArray<OperatorFragmentC, Policy::OperatorCount::kColumn>;
89 using AccumulatorTile = WmmaFragmentArray<OperatorFragmentC, Policy::OperatorCount::kCount>;
96 using AccessType = WmmaFragmentArray<OperatorFragmentC, Policy::kWmmaFragmentsPerAccess>;
105 AccessType
const *accumulators_;
115 accumulators_(reinterpret_cast<AccessType const *>(&accum)),
136 AccessType *frag_ptr =
reinterpret_cast<AccessType *
>(&frag);
139 for(
int n=0; n < Policy::OperatorCount::kColumn; n++) {
141 int accumulator_access_offset = index_ * Policy::OperatorCount::kColumn + n;
143 frag_ptr[n] = accumulators_[accumulator_access_offset];
155 #endif // !defined(__clang__) WmmaFragmentArray< OperatorFragmentC, Policy::OperatorCount::kColumn > Fragment
This is the fragment size produced by one access of the iterator.
Definition: fragment_iterator_wmma_tensor_op.h:86
CUTLASS_HOST_DEVICE FragmentIteratorWmmaTensorOp & operator++()
Increments.
Definition: fragment_iterator_wmma_tensor_op.h:121
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE FragmentIteratorWmmaTensorOp(AccumulatorTile const &accum)
Constructs an iterator.
Definition: fragment_iterator_wmma_tensor_op.h:114
WmmaTensorOpPolicy< WarpShape, OperatorShape, Layout > Policy
Definition: fragment_iterator_wmma_tensor_op.h:83
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const
Loads a fragment from the referenced part of the accumulator tile.
Definition: fragment_iterator_wmma_tensor_op.h:135
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Definition: fragment_iterator_wmma_tensor_op.h:63
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of TensorOp instructions, of which a row-oriented slice is visible per iteration.
Defines layout functions used by TensorRef and derived classes.
WarpShape_ WarpShape
Definition: fragment_iterator_wmma_tensor_op.h:77
WmmaFragmentArray< OperatorFragmentC, Policy::OperatorCount::kCount > AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: fragment_iterator_wmma_tensor_op.h:89
CUTLASS_HOST_DEVICE FragmentIteratorWmmaTensorOp & operator--()
Decrements.
Definition: fragment_iterator_wmma_tensor_op.h:128
AccumulatorTile OutputAccumulatorTile
Definition: fragment_iterator_wmma_tensor_op.h:91