56 typename OperatorShape,
57 typename OperatorElementC,
58 typename OperatorFragmentC,
69 typename OperatorShape_,
70 typename OperatorElementC_,
71 typename OperatorFragmentC_
77 using OperatorShape = OperatorShape_;
78 using OperatorElementC = OperatorElementC_;
79 using OperatorFragmentC = OperatorFragmentC_;
87 Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;
89 static int const kRealIndex = 0;
92 static int const kImaginaryIndex =
93 OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn;
102 static int const kIterations = Policy::kIterations;
107 using AccessType = Array<OperatorElementC, Policy::kElementsPerAccess>;
109 using FragmentAccessType = Array<complex<OperatorElementC>, Policy::kElementsPerAccess>;
118 AccessType
const *accumulators_;
128 accumulators_(reinterpret_cast<AccessType const *>(&accum)),
151 int index = index_ + index_offset;
153 FragmentAccessType *frag_ptr =
reinterpret_cast<FragmentAccessType *
>(&frag);
156 for (
int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
158 int accumulator_access_offset =
159 index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
161 auto const & real_accum_array = accumulators_[accumulator_access_offset + kRealIndex];
162 auto const & imag_accum_array = accumulators_[accumulator_access_offset + kImaginaryIndex / Policy::kElementsPerAccess];
166 for (
int i = 0; i < Policy::kElementsPerAccess; ++i) {
168 frag_ptr[n][i].real() = real_accum_array[i];
169 frag_ptr[n][i].imag() = imag_accum_array[i];
Definition: aligned_buffer.h:35
Array< complex< OperatorElementC >, Policy::OperatorCount::kColumn *Policy::kElementsPerAccess > Fragment
This is the fragment size produced by one access of the iterator.
Definition: fragment_iterator_complex_tensor_op.h:87
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of TensorOp instructions, of which a row-oriented slice is visible per iteration.
Array< OperatorElementC, 2 *kImaginaryIndex > AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: fragment_iterator_complex_tensor_op.h:96
CUTLASS_HOST_DEVICE FragmentIteratorComplexTensorOp & operator++()
Increments.
Definition: fragment_iterator_complex_tensor_op.h:135
CUTLASS_HOST_DEVICE FragmentIteratorComplexTensorOp(AccumulatorTile const &accum)
Constructs an iterator.
Definition: fragment_iterator_complex_tensor_op.h:127
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
WarpShape_ WarpShape
Definition: fragment_iterator_complex_tensor_op.h:76
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE FragmentIteratorComplexTensorOp & operator--()
Decrements.
Definition: fragment_iterator_complex_tensor_op.h:142
Definition: fragment_iterator_complex_tensor_op.h:61
Policy details related to the epilogue.
Definition: tensor_op_policy.h:50
Array< complex< OperatorElementC >, kImaginaryIndex > OutputAccumulatorTile
This is the complete warp-level accumulator tile expressed as complex-valued elements (kImaginaryIndex elements of complex< OperatorElementC >).
Definition: fragment_iterator_complex_tensor_op.h:99
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const
Loads a fragment from the referenced part of the accumulator tile.
Definition: fragment_iterator_complex_tensor_op.h:149
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines layout functions used by TensorRef and derived classes.