CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
epilogue_base.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
33 #pragma once
34 
35 #include <assert.h>
36 
37 #include "cutlass/cutlass.h"
38 #include "cutlass/matrix_shape.h"
39 #include "cutlass/numeric_types.h"
40 #include "cutlass/array.h"
41 #include "cutlass/layout/vector.h"
42 #include "cutlass/layout/tensor.h"
43 #include "cutlass/tensor_coord.h"
44 #include "cutlass/aligned_buffer.h"
45 
46 #include "cutlass/gemm/gemm.h"
47 
49 
51 
52 namespace cutlass {
53 namespace epilogue {
54 namespace threadblock {
55 
57 
/// Base class for threadblock-scoped epilogues (namespace cutlass::epilogue::threadblock).
///
/// It re-exports its template arguments as member types, derives the warp arrangement
/// (WarpCount), declares the SharedStorage type, and provides a constructor that moves the
/// warp tile iterator to the calling warp's tile.
///
/// NOTE(review): this listing is an extract that carries its own line number at the start of
/// every line. Listing lines 128, 136, 142, 145, 173, 176 and 182 are absent; the symbol index
/// that follows the listing names the declarations that live on most of them. The notes below
/// mark each gap -- confirm against the real header before relying on them.
59 template <
60  typename Shape_,
61  typename WarpMmaOperator_,
62  int PartitionsK,
63  typename AccumulatorFragmentIterator_,
64  typename WarpTileIterator_,
65  typename Padding_
66 >
67 class EpilogueBase {
68 public:
69 
    // Template arguments exposed under unsuffixed names.
70  using Shape = Shape_;
71  using WarpMmaOperator = WarpMmaOperator_;
72  static int const kPartitionsK = PartitionsK;
73  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
74  using WarpTileIterator = WarpTileIterator_;
75  using Padding = Padding_;
76 
79 
    /// The complete warp-level accumulator tile.
81  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
82 
    /// Accumulator element.
84  using ElementAccumulator = typename AccumulatorTile::Element;
85 
86 
    /// Warp counts along M, N and K: Shape's extent divided by WarpMmaOperator::Shape's extent
    /// in M and in N (integer division), and kPartitionsK along K.
    // presumably Shape::kM / Shape::kN are exact multiples of the warp-level extents -- TODO confirm
88  using WarpCount = gemm::GemmShape<
89  Shape::kM / WarpMmaOperator::Shape::kM,
90  Shape::kN / WarpMmaOperator::Shape::kN,
91  kPartitionsK
92  >;
93 
94 public:
95 
    /// Shared storage allocation needed by the epilogue.
97  struct SharedStorage {
98 
99  //
100  // Type definitions
101  //
102 
    /// Element type of shared memory.
104  using Element = typename WarpTileIterator::Element;
105 
    /// Tensor reference to shared memory allocation.
107  using TensorRef = typename WarpTileIterator::TensorRef;
108 
    /// Layout of shared memory allocation.
110  using Layout = typename WarpTileIterator::Layout;
111 
    /// Logical extent of the buffer, one WarpTileIterator::Shape tile per warp:
    /// rows    = WarpCount::kM * tile rows * WarpCount::kK (the K partitions add rows),
    /// columns = WarpCount::kN * tile columns.
113  using Shape = MatrixShape<
114  WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,
115  WarpCount::kN * WarpTileIterator::Shape::kColumn
116  >;
117 
    /// Allocated extent: the logical Shape enlarged by Padding in each dimension.
119  using StorageShape = MatrixShape<
120  Shape::kRow + Padding::kRow,
121  Shape::kColumn + Padding::kColumn
122  >;
123 
124  //
125  // Data members
126  //
127 
    // NOTE(review): listing line 128 is missing here. The page's index gives it as
    // `AlignedBuffer< Element, StorageShape::kCount > storage`, the member the methods below use.
129 
130  //
131  // Methods
132  //
133 
    /// Returns a pointer to the shared memory buffer.
    // NOTE(review): the signature (listing line 136) is missing from this extract; the page's
    // index gives it as `CUTLASS_DEVICE Element * data()`.
135  CUTLASS_DEVICE
137  return storage.data();
138  }
139 
    /// Returns a tensor reference to the shared memory buffer.
    // NOTE(review): the signature (listing line 142, indexed as `CUTLASS_DEVICE TensorRef
    // reference()`) and the second TensorRef constructor argument (listing line 145) are missing
    // from this extract; the second argument cannot be determined from what is visible here.
141  CUTLASS_DEVICE
143  return TensorRef(
144  storage.data(),
146  }
147 
    /// Debugging aid: the thread whose threadIdx.x is 0 prints every element of the logical
    /// Shape::kRow x Shape::kColumn region, converted to int, one row per output line.
    /// Only threadIdx.x is tested; threadIdx.y and threadIdx.z are not.
148  CUTLASS_DEVICE
149  void debug_print() {
150  if (threadIdx.x == 0) {
151 
    // `unroll 1` asks the compiler to keep this loop rolled.
152  #pragma unroll 1
153  for (int r = 0; r < Shape::kRow; ++r) {
154 
155  #pragma unroll 1
156  for (int c = 0; c < Shape::kColumn; ++c) {
157 
    // The row stride is the padded width (StorageShape::kColumn), while c stops at the
    // unpadded Shape::kColumn, so padding columns are not printed.
158  printf("%d ", int(storage.data()[r * StorageShape::kColumn + c]));
159  }
160  printf("\n");
161  }
162  }
    // The barrier sits outside the threadIdx.x branch, so every thread calling debug_print reaches it.
163  __syncthreads();
164  }
165  };
166 
167 protected:
168 
169  //
170  // Data members
171  //
172 
    // NOTE(review): listing lines 173 and 176 are missing here. The page's index gives them as
    // `SharedStorage & shared_storage_` and `WarpTileIterator warp_tile_iterator_` ("Stores a
    // warp's fragment of accumulators to SMEM."); both are initialized by the constructor below.
174 
177 
178 public:
179 
    /// Constructor.
    ///
    /// Binds shared_storage_ to `shared_storage`, builds warp_tile_iterator_ from
    /// shared_storage.reference() and `lane_idx`, then advances the iterator by a tile offset
    /// derived from `warp_idx`. `thread_idx` is not used in the visible body.
    // NOTE(review): the line carrying the constructor name (listing line 182) is missing from
    // this extract; the page's index gives it as `CUTLASS_DEVICE EpilogueBase(SharedStorage
    // &shared_storage, int thread_idx, int warp_idx, int lane_idx)`.
181  CUTLASS_DEVICE
183  SharedStorage &shared_storage,
184  int thread_idx,
185  int warp_idx,
186  int lane_idx
187  ):
188  shared_storage_(shared_storage),
189  warp_tile_iterator_(shared_storage.reference(), lane_idx) {
190 
191  // Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
192  //
193  // _m: the warp's position within the threadblock along the M dimension
194  // _n: the warp's position within the threadblock along the N dimension
195  // _k: the warp's position within the threadblock along the K dimension
196 
    // warp_idx is decomposed with M varying fastest, then N, then K.
197  int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
198  int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
199  int warp_m = warp_mn % WarpCount::kM;
200  int warp_n = warp_mn / WarpCount::kM;
201 
    // Offset in tile units: each K partition adds WarpCount::kM tile rows, which matches the
    // WarpCount::kK factor in SharedStorage::Shape's row count.
202  MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
203 
204  warp_tile_iterator_.add_tile_offset(warp_offset);
205  }
206 };
207 
209 
210 } // namespace threadblock
211 } // namespace epilogue
212 } // namespace cutlass
213 
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
CUTLASS_DEVICE void debug_print()
Definition: epilogue_base.h:149
static int const kColumn
columns of a matrix
Definition: matrix_shape.h:44
WarpTileIterator warp_tile_iterator_
Stores a warp's fragment of accumulators to SMEM.
Definition: epilogue_base.h:176
SharedStorage & shared_storage_
Definition: epilogue_base.h:173
Templates implementing how threads are mapped to a given tile.
WarpMmaOperator_ WarpMmaOperator
Definition: epilogue_base.h:71
Shared storage allocation needed by the epilogue.
Definition: epilogue_base.h:97
CUTLASS_DEVICE Element * data()
Returns a pointer to the shared memory buffer.
Definition: epilogue_base.h:136
Defines common types used for all GEMM-like operators.
typename AccumulatorTile::Element ElementAccumulator
Accumulator element.
Definition: epilogue_base.h:84
typename WarpTileIterator::TensorRef TensorRef
Tensor reference to shared memory allocation.
Definition: epilogue_base.h:107
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
static int const kK
Definition: include/cutlass/gemm/gemm.h:60
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
Defines a Shape template for matrix tiles.
static int const kPartitionsK
Definition: epilogue_base.h:72
typename WarpTileIterator::Element Element
Element type of shared memory.
Definition: epilogue_base.h:104
Defines a canonical coordinate for rank=4 tensors offering named indices.
AlignedBuffer< Element, StorageShape::kCount > storage
Definition: epilogue_base.h:128
static int const kRow
rows of a matrix
Definition: matrix_shape.h:43
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
Top-level include for all CUTLASS numeric types.
Modifies semantics of cutlass::Array<> to provide guaranteed alignment.
Definition: aligned_buffer.h:45
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_HOST_DEVICE pointer data()
Definition: aligned_buffer.h:84
typename WarpTileIterator::Layout Layout
Layout of shared memory allocation.
Definition: epilogue_base.h:110
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: epilogue_base.h:73
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines layout functions used for rank=1 vectors.
Shape_ Shape
Definition: epilogue_base.h:70
Base class for epilogues defining warp-level.
Definition: epilogue_base.h:67
static CUTLASS_HOST_DEVICE RowMajor packed(MatrixCoord const &extent)
Helper returns a layout to a tightly packed tensor.
Definition: layout/matrix.h:93
Padding_ Padding
Definition: epilogue_base.h:75
CUTLASS_DEVICE EpilogueBase(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: epilogue_base.h:182
WarpTileIterator_ WarpTileIterator
Definition: epilogue_base.h:74
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: epilogue_base.h:81
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_DEVICE TensorRef reference()
Returns a tensor reference to the shared memory buffer.
Definition: epilogue_base.h:142
static int const kN
Definition: include/cutlass/gemm/gemm.h:59