CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
fragment_iterator_volta_tensor_op.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
38 #pragma once
39 
40 #include "cutlass/array.h"
41 #include "cutlass/layout/matrix.h"
42 #include "cutlass/gemm/gemm.h"
43 
45 
47 
48 namespace cutlass {
49 namespace epilogue {
50 namespace warp {
51 
53 
// Primary template parameter list for FragmentIteratorVoltaTensorOp.
// Parameters (names suggest, confirm against full source):
//   WarpShape            - warp-level tile shape — presumably a gemm::GemmShape
//   InterleavedTileShape - interleaved tensor-op tile shape (specialized below as GemmShape<32, 32, 4>)
//   ElementC             - accumulator element type (half_t and float specializations follow)
//   Layout               - output layout (layout::RowMajor specializations follow)
// NOTE(review): the declaration line itself (listing line 61) was dropped by the
// documentation extraction — verify against the original header.
55 template <
56  typename WarpShape,
57  typename InterleavedTileShape,
58  typename ElementC,
59  typename Layout
60 >
62 
64 
66 template <
67  typename WarpShape_
68 >
// Partial specialization for half_t accumulators in an interleaved 32x32x4
// Volta tensor-op tile with a row-major output layout.
// NOTE(review): this listing came from generated documentation; gaps in the
// embedded line numbers (75, 77-80, 86, 89, 91, 102, 105, 109-112, 118-120,
// 125-127, 132-133, 140, 146) mark dropped source lines, including the
// `Policy` typedef referenced below (per the index trailer, a
// VoltaTensorOpPolicy from volta_tensor_op_policy.h) and the constructor /
// operator++ / operator-- declaration lines whose bodies appear further down.
69 class FragmentIteratorVoltaTensorOp<WarpShape_, gemm::GemmShape<32, 32, 4>, half_t, layout::RowMajor> {
70 public:
71 
72  using WarpShape = WarpShape_;
73  using InterleavedTileShape = gemm::GemmShape<32, 32, 4>;
74  using ElementC = half_t;
76 
79 
// Array type for aligned memory accesses (from the missing Policy typedef).
81  using AccessType = typename Policy::AccessType;
82 
// Fragment produced by one access of the iterator.
84  using Fragment = typename Policy::Fragment;
85 
// The complete warp-level accumulator tile.
87  using AccumulatorTile = typename Policy::AccumulatorTile;
88 
90 
// Number of fragment-load iterations, supplied by Policy.
92  static int const kIterations = Policy::kIterations;
93 
94 private:
95 
96 private:
97 
98  //
99  // Data members
100  //
101 
// Accumulator tile viewed as a sequence of aligned AccessType elements
// (established by the reinterpret_cast in the constructor below).
103  AccessType const *accumulators_;
104 
// Current iteration index; advanced by operator++ and rewound by operator--.
106  int index_;
107 
108 public:
109 
// Constructor initializer list — the signature line
// `FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum)` (listing
// lines 109-112 per the index trailer) is missing from this extraction.
113  accumulators_(reinterpret_cast<AccessType const *>(&accum)),
114  index_(0) {
115 
116  }
117 
// operator++ body (declaration line missing): advance to the next fragment.
121  ++index_;
122  return *this;
123  }
124 
// operator-- body (declaration line missing): step back one fragment.
128  --index_;
129  return *this;
130  }
131 
// Loads a fragment from the part of the accumulator tile selected by index_.
// NOTE(review): index_offset is never read in the visible body — confirm
// against the full source whether it is intentionally ignored here.
134  void load(Fragment &frag, int index_offset = 0) const {
135 
136  AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
137 
// Number of AccessType accesses covering one MMA instruction's accumulators.
138  static int const kAccessesPerMma = Policy::kElementsPerMma / Policy::kElementsPerAccess;
139 
141  for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
142 
// Base access index for this column of interleaved tiles; bit 1 of index_
// appears to select the tile row — presumed from the bit arithmetic, verify
// against VoltaTensorOpPolicy.
143  int tile_access_idx =
144  (tile_n * Policy::TileIterations::kRow + (index_ & 2) / 2) * Policy::MmaIterations::kCount * kAccessesPerMma;
145 
147  for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * kAccessesPerMma; ++mma_n) {
148 
// Bit 0 of index_ and the low bits of mma_n select the access within the MMA.
149  int mma_access_idx = ((mma_n & 1) * 2 + (index_ & 1)) * kAccessesPerMma + (mma_n & 2) / 2;
150 
151  frag_ptr[tile_n * Policy::MmaIterations::kColumn * kAccessesPerMma +
152  mma_n] = accumulators_[tile_access_idx + mma_access_idx];
153  }
154  }
155  }
156 };
157 
159 
161 template <
162  typename WarpShape_
163 >
// Partial specialization for float accumulators in an interleaved 32x32x4
// Volta tensor-op tile with a row-major output layout.
// NOTE(review): as with the half_t specialization above, gaps in the embedded
// listing numbers (170, 172-175, 178, 181, 184, 195, 198, 203-205, 210-212,
// 217-219, 224-225, 232, 235, 238) mark source lines dropped by the
// documentation extraction — notably the `Policy` typedef and the
// constructor / operator declaration lines.
164 class FragmentIteratorVoltaTensorOp<WarpShape_, gemm::GemmShape<32, 32, 4>, float, layout::RowMajor> {
165 public:
166 
167  using WarpShape = WarpShape_;
168  using InterleavedTileShape = gemm::GemmShape<32, 32, 4>;
169  using ElementC = float;
171 
174 
// Array type for aligned memory accesses (from the missing Policy typedef).
176  using AccessType = typename Policy::AccessType;
177 
// Fragment produced by one access of the iterator.
179  using Fragment = typename Policy::Fragment;
180 
// The complete warp-level accumulator tile.
182  using AccumulatorTile = typename Policy::AccumulatorTile;
183 
// Number of fragment-load iterations, supplied by Policy.
185  static int const kIterations = Policy::kIterations;
186 
187 private:
188 
189 private:
190 
191  //
192  // Data members
193  //
194 
// Accumulator tile viewed as a sequence of aligned AccessType elements
// (established by the reinterpret_cast in the constructor below).
196  AccessType const *accumulators_;
197 
// Current iteration index; advanced by operator++ and rewound by operator--.
199  int index_;
200 
201 public:
202 
// Constructor initializer list — the signature line
// `FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum)` (listing
// lines 203-205 per the index trailer) is missing from this extraction.
206  accumulators_(reinterpret_cast<AccessType const *>(&accum)),
207  index_(0) {
208  }
209 
// operator++ body (declaration line missing): advance to the next fragment.
213  ++index_;
214  return *this;
215  }
216 
// operator-- body (declaration line missing): step back one fragment.
220  --index_;
221  return *this;
222  }
223 
// Loads a fragment from the part of the accumulator tile selected by index_.
// NOTE(review): index_offset is never read in the visible body — confirm
// against the full source whether it is intentionally ignored here.
226  void load(Fragment &frag, int index_offset = 0) const {
227 
228  AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
229 
// Registers per MMA row for float accumulators (two 32-bit values).
230  int const kRegsPerMmaRow = 2;
231 
233  for (int reg_row = 0; reg_row < Policy::kRowsPerMmaTile; ++reg_row) {
234 
236  for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
237 
239  for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * 2; ++mma_n) {
240 
// MMA instruction selected by the low bits of index_, the tile column, and
// the low bit of mma_n — bit-decoding presumed, verify against
// VoltaTensorOpPolicy.
241  int mma_idx = (index_ & 1) + (index_ & 2) * Policy::MmaIterations::kCount / 2 +
242  (tile_n * Policy::TileIterations::kRow) * Policy::MmaIterations::kCount + (mma_n & 1) * 2;
243 
// Register offset within the MMA, then the flat register index; the division
// converts from elements to AccessType units.
244  int reg_offset = reg_row * kRegsPerMmaRow + (mma_n & 2) * 2;
245  int reg_idx = mma_idx * Policy::kElementsPerMma + reg_offset;
246 
247  *frag_ptr = accumulators_[reg_idx / Policy::kElementsPerAccess];
248  ++frag_ptr;
249  }
250  }
251  }
252  }
253 };
254 
256 
257 
258 } // namespace warp
259 } // namespace epilogue
260 } // namespace cutlass
261 
263 
typename Policy::AccumulatorTile AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: fragment_iterator_volta_tensor_op.h:182
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp & operator++()
Increments.
Definition: fragment_iterator_volta_tensor_op.h:120
Definition: aligned_buffer.h:35
typename Policy::AccessType AccessType
Array type for aligned memory accesses.
Definition: fragment_iterator_volta_tensor_op.h:176
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp & operator--()
Decrements.
Definition: fragment_iterator_volta_tensor_op.h:219
IEEE half-precision floating-point type.
Definition: half.h:126
typename Policy::AccumulatorTile AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: fragment_iterator_volta_tensor_op.h:87
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp & operator++()
Increments.
Definition: fragment_iterator_volta_tensor_op.h:212
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum)
Constructs an iterator.
Definition: fragment_iterator_volta_tensor_op.h:205
typename Policy::AccessType AccessType
Array type for aligned memory accesses.
Definition: fragment_iterator_volta_tensor_op.h:81
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp & operator--()
Decrements.
Definition: fragment_iterator_volta_tensor_op.h:127
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const
Loads a fragment from the referenced part of the accumulator tile.
Definition: fragment_iterator_volta_tensor_op.h:134
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Definition: fragment_iterator_volta_tensor_op.h:61
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const
Loads a fragment from the referenced part of the accumulator tile.
Definition: fragment_iterator_volta_tensor_op.h:226
Defines layout functions used by TensorRef and derived classes.
Policy details related to the epilogue.
Definition: volta_tensor_op_policy.h:52
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum)
Constructs an iterator.
Definition: fragment_iterator_volta_tensor_op.h:112
typename Policy::Fragment Fragment
This is the fragment size produced by one access of the iterator.
Definition: fragment_iterator_volta_tensor_op.h:84
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of TensorOp instructions, of which a row-oriented slice is visible per iteration.
typename Policy::Fragment Fragment
This is the fragment size produced by one access of the iterator.
Definition: fragment_iterator_volta_tensor_op.h:179