CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
interleaved_epilogue.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
33 #pragma once
34 
35 #include <assert.h>
36 
37 #include "cutlass/cutlass.h"
38 #include "cutlass/numeric_types.h"
39 #include "cutlass/array.h"
40 #include "cutlass/layout/vector.h"
41 #include "cutlass/layout/tensor.h"
42 #include "cutlass/tensor_coord.h"
43 #include "cutlass/aligned_buffer.h"
44 
45 #include "cutlass/gemm/gemm.h"
46 
49 
52 
54 
55 namespace cutlass {
56 namespace epilogue {
57 namespace threadblock {
58 
60 
62 template <
64  typename Shape_,
66  typename WarpMmaOperator_,
68  int PartitionsK,
70  typename OutputTileIterator_,
72  typename AccumulatorFragmentIterator_,
74  typename OutputOp_,
76  int InterleavedK,
78  bool IsBetaZero = false>
80  public:
81  using Shape = Shape_;
82  using WarpMmaOperator = WarpMmaOperator_;
83  static int const kPartitionsK = PartitionsK;
84  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
85  using OutputTileIterator = OutputTileIterator_;
86  using OutputOp = OutputOp_;
87 
90 
92  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
93 
95  using ElementAccumulator = typename AccumulatorTile::Element;
96 
98  using ElementOutput = typename OutputTileIterator::Element;
99 
101  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
102 
104  using TensorRef = typename OutputTileIterator::TensorRef;
105 
107  using SyncTensorRef =
109 
111  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
112 
114  using OutputAccessType = Array<typename OutputTileIterator::Element,
115  OutputTileIterator::kElementsPerAccess>;
116 
118  using AccumulatorAccessType =
119  Array<ElementAccumulator, OutputTileIterator::kElementsPerAccess>;
120 
122  using WarpCount =
123  gemm::GemmShape<Shape::kM / WarpMmaOperator::Shape::kM,
124  Shape::kN / WarpMmaOperator::Shape::kN, kPartitionsK>;
125 
126  public:
127  static_assert(OutputTileIterator::kElementsPerAccess,
128  "This must not be zero.");
129 
130  static_assert(!(OutputTileIterator::Fragment::kElements %
131  OutputTileIterator::kElementsPerAccess),
132  "Divisibility");
133 
135  struct SharedStorage {};
136 
137 
138  public:
140  CUTLASS_DEVICE
142  SharedStorage &shared_storage,
143  int thread_idx,
144  int warp_idx,
145  int lane_idx
146  ) {}
147 
149  CUTLASS_DEVICE
151  OutputOp const &output_op,
152  OutputTileIterator destination_iterator,
153  AccumulatorTile const &accumulators,
154  OutputTileIterator source_iterator) {
155 
156  //
157  // Predicated tile iterators constructed from members
158  //
159 
160  if (IsBetaZero && output_op.is_source_needed())
161  assert(0);
162 
163  typename OutputTileIterator::Fragment source_fragment;
164 
165  if (!IsBetaZero) {
166  if (!output_op.is_source_needed()) {
167  source_iterator.clear_mask();
168  }
169  }
170 
171  source_fragment.clear();
172 
173  //
174  // Iterator over warp-level accumulator fragment
175  //
176 
177  AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
178 
179  //
180  // Iterate over accumulator tile
181  //
182 
184  for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
185  //
186  // Load the source
187  //
188 
189  if (!IsBetaZero) {
190  source_iterator.set_iteration_index(iter);
191  source_iterator.load(source_fragment);
192  ++source_iterator;
193  }
194 
195  //
196  // Convert fragment
197  //
198 
199  typename AccumulatorFragmentIterator::Fragment accum_fragment;
200 
201  accum_fragment_iterator.load(accum_fragment);
202  ++accum_fragment_iterator;
203 
204  //
205  // Compute the output result
206  //
207 
208  typename OutputTileIterator::Fragment output_fragment;
209  apply_output_operator_(output_op, output_fragment, accum_fragment, source_fragment);
210 
211  //
212  // Store the final result
213  //
214 
215  destination_iterator.set_iteration_index(iter);
216  destination_iterator.store(output_fragment);
217  ++destination_iterator;
218  }
219  }
220 
221  private:
223  CUTLASS_DEVICE
224  void apply_output_operator_(
225  OutputOp const &output_op,
226  typename OutputTileIterator::Fragment &output_fragment,
227  typename AccumulatorFragmentIterator::Fragment const
228  &aligned_accum_fragment,
229  typename OutputTileIterator::Fragment const &source_fragment) {
230  OutputAccessType *output_frag_ptr =
231  reinterpret_cast<OutputAccessType *>(&output_fragment);
232 
233  AccumulatorAccessType const *compute_frag_ptr =
234  reinterpret_cast<AccumulatorAccessType const *>(
235  &aligned_accum_fragment);
236 
237  OutputAccessType const *source_frag_ptr =
238  reinterpret_cast<OutputAccessType const *>(&source_fragment);
239 
240  int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
241  OutputTileIterator::kElementsPerAccess;
242 
244  for (int i = 0; i < kOutputOpIterations; ++i) {
245  // Call the output operator
246  output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
247  }
248  }
249 };
250 
252 
253 } // namespace threadblock
254 } // namespace epilogue
255 } // namespace cutlass
256 
Shape_ Shape
Definition: interleaved_epilogue.h:81
Definition: aligned_buffer.h:35
typename AccumulatorTile::Element ElementAccumulator
Accumulator element.
Definition: interleaved_epilogue.h:95
Templates implementing how threads are mapped to a given tile.
CUTLASS_DEVICE InterleavedEpilogue(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: interleaved_epilogue.h:141
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: interleaved_epilogue.h:92
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Array< ElementAccumulator, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType
Array type used by output functor.
Definition: interleaved_epilogue.h:119
Epilogue operator without splitk.
Definition: interleaved_epilogue.h:79
Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType
Array type used to output.
Definition: interleaved_epilogue.h:115
OutputOp_ OutputOp
Definition: interleaved_epilogue.h:86
Defines common types used for all GEMM-like operators.
typename OutputTileIterator::ConstTensorRef ConstTensorRef
Const tensor reference to source tensor.
Definition: interleaved_epilogue.h:111
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
typename OutputTileIterator::TensorRef TensorRef
Tensor reference to destination tensor.
Definition: interleaved_epilogue.h:104
Shared storage allocation needed by the epilogue.
Definition: interleaved_epilogue.h:135
WarpMmaOperator_ WarpMmaOperator
Definition: interleaved_epilogue.h:82
typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef
Tensor reference to sync tensor.
Definition: interleaved_epilogue.h:108
Definition: tensor_ref.h:146
Defines a canonical coordinate for rank=4 tensors offering named indices.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
OutputTileIterator_ OutputTileIterator
Definition: interleaved_epilogue.h:85
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)
Streams the result to global memory.
Definition: interleaved_epilogue.h:150
static int const kElementsPerAccess
Output access size.
Definition: interleaved_epilogue.h:101
Defines layout functions used for rank=1 vectors.
Templates implementing storing of tiles from pitch-linear rank=2 tensors.
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Definition: layout/matrix.h:343
static int const kPartitionsK
Definition: interleaved_epilogue.h:83
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: interleaved_epilogue.h:84
typename OutputTileIterator::Element ElementOutput
Output element.
Definition: interleaved_epilogue.h:98
Basic include for CUTLASS.