CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
epilogue.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
33 #pragma once
34 
35 #include <assert.h>
36 
37 #include "cutlass/cutlass.h"
38 #include "cutlass/numeric_types.h"
39 #include "cutlass/array.h"
40 #include "cutlass/layout/vector.h"
41 #include "cutlass/layout/tensor.h"
42 #include "cutlass/tensor_coord.h"
43 #include "cutlass/aligned_buffer.h"
44 #include "cutlass/functional.h"
45 
46 #include "cutlass/gemm/gemm.h"
47 
50 
53 
55 
56 namespace cutlass {
57 namespace epilogue {
58 namespace threadblock {
59 
61 
63 template <
64  typename Shape_,
65  typename WarpMmaOperator_,
66  int PartitionsK,
67  typename OutputTileIterator_,
68  typename AccumulatorFragmentIterator_,
69  typename WarpTileIterator_,
70  typename SharedLoadIterator_,
71  typename OutputOp_,
72  typename Padding_
73 >
74 class Epilogue :
75  public EpilogueBase<
76  Shape_,
77  WarpMmaOperator_,
78  PartitionsK,
79  AccumulatorFragmentIterator_,
80  WarpTileIterator_,
81  Padding_> {
82 
83 public:
84 
85  using Base = EpilogueBase<
86  Shape_,
87  WarpMmaOperator_,
88  PartitionsK,
89  AccumulatorFragmentIterator_,
90  WarpTileIterator_,
91  Padding_>;
92 
93  using Shape = Shape_;
94  using WarpMmaOperator = WarpMmaOperator_;
95  static int const kPartitionsK = PartitionsK;
96  using OutputTileIterator = OutputTileIterator_;
97  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
98  using WarpTileIterator = WarpTileIterator_;
99  using SharedLoadIterator = SharedLoadIterator_;
100  using OutputOp = OutputOp_;
101  using Padding = Padding_;
102 
105  using LongIndex = typename Layout::LongIndex;
106 
109 
111  using ElementAccumulator = typename WarpTileIterator::Element;
112 
113 
115  using ElementOutput = typename OutputTileIterator::Element;
116 
118  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
119 
121  using TensorRef = typename OutputTileIterator::TensorRef;
122 
125 
127  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
128 
130  using OutputAccessType = Array<
131  typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
132 
134  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
135 
137  using WarpCount = typename Base::WarpCount;
138 
139 public:
140 
141 
142  static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
143  "Mismatch between shared load iterator and output tile iterator.");
144 
145  static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
146 
147  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
148  "Divisibility");
149 
150 private:
151 
153  SharedLoadIterator shared_load_iterator_;
154 
155 public:
156 
158  CUTLASS_DEVICE
160  typename Base::SharedStorage &shared_storage,
161  int thread_idx,
162  int warp_idx,
163  int lane_idx
164  ):
165  Base(shared_storage, thread_idx, warp_idx, lane_idx),
166  shared_load_iterator_(shared_storage.reference(), thread_idx) { }
167 
169  CUTLASS_DEVICE
171  OutputOp const &output_op,
172  OutputTileIterator destination_iterator,
173  AccumulatorTile const &accumulators,
174  OutputTileIterator source_iterator) {
175 
176 
177  typename OutputTileIterator::Fragment source_fragment;
178 
179  if (!output_op.is_source_needed()) {
180  source_iterator.clear_mask();
181  }
182 
183  source_fragment.clear();
184 
185  //
186  // Iterator over warp-level accumulator fragment
187  //
188 
189  AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
190 
191  //
192  // Iterate over accumulator tile
193  //
194 
196  for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
197 
198  //
199  // Load the source
200  //
201 
202  source_iterator.load(source_fragment);
203  ++source_iterator;
204 
205  //
206  // Convert and store fragment
207  //
208 
209  __syncthreads();
210 
211  typename AccumulatorFragmentIterator::Fragment accum_fragment;
212 
213  accum_fragment_iterator.load(accum_fragment);
214  ++accum_fragment_iterator;
215 
216  this->warp_tile_iterator_.store(accum_fragment);
217 
218  __syncthreads();
219 
220  //
221  // Load fragments from shared memory
222  //
223 
224  typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
225 
226  shared_load_iterator_.load(aligned_accum_fragment[0]);
227 
228  // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
229  if (kPartitionsK > 1)
230  {
232  const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
233 
235  for ( int i = 1; i < kPartitionsK; ++i) {
236  shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
237  shared_load_iterator_.load(aligned_accum_fragment[i]);
238  aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
239  }
240 
241  shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
242  }
243 
244  //
245  // Compute the output result
246  //
247 
248  typename OutputTileIterator::Fragment output_fragment;
249 
250  apply_output_operator_(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
251 
252 
253  //
254  // Store the final result
255  //
256 
257  destination_iterator.store(output_fragment);
258  ++destination_iterator;
259 
260  }
261  }
262 
263 private:
264 
266  CUTLASS_DEVICE
267  void apply_output_operator_(
268  typename OutputTileIterator::Fragment &output_fragment,
269  OutputOp const &output_op,
270  typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
271  typename OutputTileIterator::Fragment const &source_fragment) {
272 
273  OutputAccessType *output_frag_ptr =
274  reinterpret_cast<OutputAccessType *>(&output_fragment);
275 
276  AccumulatorAccessType const *compute_frag_ptr =
277  reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
278 
279  OutputAccessType const *source_frag_ptr =
280  reinterpret_cast<OutputAccessType const *>(&source_fragment);
281 
282  int const kOutputOpIterations =
283  OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
284 
286  for (int i = 0; i < kOutputOpIterations; ++i) {
287 
288  // Call the output operator
289  output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
290  }
291  }
292 };
293 
295 
296 } // namespace threadblock
297 } // namespace epilogue
298 } // namespace cutlass
299 
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
typename Layout::LongIndex LongIndex
Definition: epilogue.h:105
typename Base::WarpCount WarpCount
Number of warps.
Definition: epilogue.h:137
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: shared_load_iterator.h:91
Definition: aligned_buffer.h:35
WarpTileIterator warp_tile_iterator_
Stores a warp's fragment of accumulators to SMEM.
Definition: epilogue_base.h:176
Templates implementing how threads are mapped to a given tile.
Shared storage allocation needed by the epilogue.
Definition: epilogue_base.h:97
CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)
Streams the result to global memory.
Definition: epilogue.h:170
OutputTileIterator_ OutputTileIterator
Definition: epilogue.h:96
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Defines common types used for all GEMM-like operators.
CUTLASS_DEVICE Epilogue(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: epilogue.h:159
Shape_ Shape
Definition: epilogue.h:93
typename OutputTileIterator::TensorRef TensorRef
Tensor reference to destination tensor.
Definition: epilogue.h:121
gemm::GemmShape< Shape::kM/WarpMmaOperator::Shape::kM, Shape::kN/WarpMmaOperator::Shape::kN, kPartitionsK > WarpCount
Number of warps.
Definition: epilogue_base.h:92
Definition: functional.h:46
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
static int const kPartitionsK
Definition: epilogue.h:95
OutputOp_ OutputOp
Definition: epilogue.h:100
Definition: tensor_ref.h:146
Padding_ Padding
Definition: epilogue.h:101
Defines a canonical coordinate for rank=4 tensors offering named indices.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: epilogue.h:97
Top-level include for all CUTLASS numeric types.
#define static_assert(__e, __m)
Definition: platform.h:153
typename OutputTileIterator::ConstTensorRef ConstTensorRef
Const tensor reference to source tensor.
Definition: epilogue.h:127
WarpTileIterator_ WarpTileIterator
Definition: epilogue.h:98
SharedLoadIterator_ SharedLoadIterator
Definition: epilogue.h:99
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Epilogue operator without splitk.
Definition: epilogue.h:74
typename WarpTileIterator::Element ElementAccumulator
Accumulator element.
Definition: epilogue.h:111
Defines layout functions used for rank=1 vectors.
Templates implementing storing of tiles from pitch-linear rank=2 tensors.
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Base class for epilogues defining warp-level.
Definition: epilogue_base.h:67
WarpMmaOperator_ WarpMmaOperator
Definition: epilogue.h:94
typename Base::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: epilogue.h:108
Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType
Array type used to output.
Definition: epilogue.h:131
static int const kElementsPerAccess
Output access size.
Definition: epilogue.h:118
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: epilogue_base.h:81
typename OutputTileIterator::Element ElementOutput
Output element.
Definition: epilogue.h:115
Basic include for CUTLASS.
typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef
Tensor reference to sync tensor.
Definition: epilogue.h:124
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Array< typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType
Array type used by output functor.
Definition: epilogue.h:134