CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
fragment_iterator_complex_tensor_op.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief This defines a fragment iterator for visiting the fragments of a complex-valued
    accumulator tile that participate in one warp-level store operation.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"

#include "cutlass/epilogue/warp/tensor_op_policy.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Fragment iterator for the accumulator tile of a complex-valued warp-level TensorOp
template <
  typename WarpShape,          ///< shape of the warp-level GEMM tile (concept: MatrixShape)
  typename OperatorShape,      ///< matrix multiply operation shape (concept: gemm::GemmShape)
  typename OperatorElementC,   ///< matrix multiply operation data type
  typename OperatorFragmentC,  ///< matrix multiply operation fragment (concept: Array)
  typename Layout              ///< target shared memory layout
>
class FragmentIteratorComplexTensorOp;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for row-major shared memory
template <
  typename WarpShape_,         ///< shape of the warp-level GEMM tile
  typename OperatorShape_,     ///< matrix multiply operation shape (concept: gemm::GemmShape)
  typename OperatorElementC_,  ///< matrix multiply operation data type
  typename OperatorFragmentC_  ///< matrix multiply operation fragment (concept: Array)
>
class FragmentIteratorComplexTensorOp<WarpShape_, OperatorShape_, OperatorElementC_, OperatorFragmentC_, layout::RowMajor> {
public:

  using WarpShape = WarpShape_;
  using OperatorShape = OperatorShape_;
  using OperatorElementC = OperatorElementC_;
  using OperatorFragmentC = OperatorFragmentC_;
  using Layout = layout::RowMajor;

  /// Policy describing the arrangement of TensorOp accumulators within the warp
  using Policy = TensorOpPolicy<WarpShape, OperatorShape, Layout>;

  /// This is the fragment size produced by one access of the iterator.
  using Fragment = Array<
    complex<OperatorElementC>,
    Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;

  /// Offset of the real-valued accumulators within the tile
  static int const kRealIndex = 0;

  /// Offset of the imaginary-valued accumulators within the tile
  static int const kImaginaryIndex =
    OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn;

  /// This is the complete warp-level accumulator tile: all real parts followed by all
  /// imaginary parts.
  using AccumulatorTile = Array<OperatorElementC, 2 * kImaginaryIndex>;

  /// This is the complete warp-level accumulator tile in complex-valued form.
  using OutputAccumulatorTile = Array<complex<OperatorElementC>, kImaginaryIndex>;

  /// Number of times this iterator can be incremented
  static int const kIterations = Policy::kIterations;
private:

  /// Internal access type for a group of Policy::kElementsPerAccess accumulator elements
  using AccessType = Array<OperatorElementC, Policy::kElementsPerAccess>;

  /// Complex-valued access type for the output fragment
  using FragmentAccessType = Array<complex<OperatorElementC>, Policy::kElementsPerAccess>;

private:

  //
  // Data members
  //

  /// Accumulator tile
  AccessType const *accumulators_;

  /// Internal index
  int index_;
public:

  /// Constructs an iterator
  CUTLASS_HOST_DEVICE
  FragmentIteratorComplexTensorOp(AccumulatorTile const &accum):
    accumulators_(reinterpret_cast<AccessType const *>(&accum)),
    index_(0) {

  }

  /// Increments
  CUTLASS_HOST_DEVICE
  FragmentIteratorComplexTensorOp &operator++() {
    ++index_;
    return *this;
  }

  /// Decrements
  CUTLASS_HOST_DEVICE
  FragmentIteratorComplexTensorOp &operator--() {
    --index_;
    return *this;
  }
  /// Loads a fragment from the referenced part of the accumulator tile
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag, int index_offset = 0) const {

    int index = index_ + index_offset;

    FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {

      int accumulator_access_offset =
        index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;

      auto const & real_accum_array = accumulators_[accumulator_access_offset + kRealIndex];
      auto const & imag_accum_array = accumulators_[accumulator_access_offset + kImaginaryIndex / Policy::kElementsPerAccess];

      // Pack real and imaginary parts into a structure. This is likely to result in MOVs.
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Policy::kElementsPerAccess; ++i) {

        frag_ptr[n][i].real() = real_accum_array[i];
        frag_ptr[n][i].imag() = imag_accum_array[i];
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace epilogue
} // namespace cutlass
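
The iterator above assumes a planar-complex accumulator layout: the first kImaginaryIndex elements of AccumulatorTile hold the real parts, the second kImaginaryIndex elements hold the imaginary parts, and each call to load() interleaves one row-slice of both halves into complex values. The standalone sketch below illustrates that indexing scheme. It does not use CUTLASS itself; the constants are hypothetical stand-ins for the corresponding Policy quantities, and the column stride is simplified to contiguous accesses.

#include <array>
#include <complex>
#include <cstdio>

int main() {
  constexpr int kElementsPerAccess = 2;  // stand-in for Policy::kElementsPerAccess
  constexpr int kColumns = 2;            // stand-in for Policy::OperatorCount::kColumn
  constexpr int kImaginaryIndex = 8;     // number of elements in the real half of the tile

  // Accumulator tile: real parts in [0, kImaginaryIndex), imaginary parts after.
  std::array<float, 2 * kImaginaryIndex> accum{};
  for (int i = 0; i < kImaginaryIndex; ++i) {
    accum[i] = float(i);                     // real part
    accum[kImaginaryIndex + i] = float(-i);  // imaginary part
  }

  // One "fragment": kColumns accesses of a row-slice, packed as complex values,
  // mirroring the two reads (real_accum_array, imag_accum_array) in load().
  std::array<std::complex<float>, kColumns * kElementsPerAccess> frag;
  int index = 0;  // iterator position, i.e. index_ in FragmentIteratorComplexTensorOp
  for (int n = 0; n < kColumns; ++n) {
    int offset = (index + n) * kElementsPerAccess;  // simplified accumulator_access_offset
    for (int i = 0; i < kElementsPerAccess; ++i) {
      frag[n * kElementsPerAccess + i] = {
          accum[offset + i],                     // from the real half (kRealIndex)
          accum[kImaginaryIndex + offset + i]};  // from the imaginary half (kImaginaryIndex)
    }
  }

  for (auto const &c : frag) {
    std::printf("(%g, %g) ", c.real(), c.imag());
  }
  std::printf("\n");

  return 0;
}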