CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
direct_epilogue_tensor_op.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include "cutlass/cutlass.h"
33 #include "cutlass/numeric_types.h"
34 #include "cutlass/array.h"
35 
36 #include "cutlass/gemm/gemm.h"
37 
39 
40 namespace cutlass {
41 namespace epilogue {
42 namespace threadblock {
43 
45 
////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator that streams warp-level tensor-op accumulators directly to
/// global memory, applying a conversion operator and an output functor per element.
///
/// NOTE(review): this block was reconstructed from a garbled Doxygen listing —
/// declarations the extraction dropped (class name, Layout/TensorRef typedefs,
/// Params data members, constructor headers, private state, unroll pragmas) were
/// recovered from the accompanying member-documentation index. Verify against the
/// original direct_epilogue_tensor_op.h before relying on exact line placement.
template <
  typename Shape_,      ///< Threadblock-scoped tile shape (concept: gemm::GemmShape)
  typename Operator_,   ///< Warp-level operator (provides Shape, Policy, FragmentC)
  int PartitionsK,      ///< Number of partitions of the K dimension (must be 1)
  typename Element_,    ///< Data type of the output tensor
  typename OutputOp_,   ///< Function object computing the final output
  typename ConvertOp_   ///< Conversion operator applied to each accumulator
>
class DirectEpilogueTensorOp {
public:

  using Shape = Shape_;
  using Operator = Operator_;

  /// Number of warps spanning the threadblock-scoped tile
  using WarpCount = gemm::GemmShape<
    Shape::kM / Operator::Shape::kM,
    Shape::kN / Operator::Shape::kN,
    PartitionsK           // fix: removed illegal trailing comma in template-argument list
  >;

  static_assert(PartitionsK == 1,
    "Direct epilogue cannot be used when the threadblock tile is partitioned along the K dimension.");

  /// Accumulator tile is really the warp-scoped tile
  using FragmentC = typename Operator::FragmentC;

  /// Data type of output tensor
  using Element = Element_;

  /// Mapping function for the output tensor
  using Layout = layout::RowMajor;

  /// Function operator computing final output
  using OutputOp = OutputOp_;

  /// Conversion operator to shared memory
  using ConvertOp = ConvertOp_;

  /// Reference to source and destination tensors
  /// (signature recovered from the documentation index — confirm against tensor_ref.h)
  using TensorRef = cutlass::TensorRef<Element, Layout::kRank, Layout>;

public:

  /// Parameters structure for host-constructible state
  struct Params {

    //
    // Data members
    //

    TensorRef destination_ref;              ///< Destination tensor written by the epilogue
    TensorRef source_ref;                   ///< Source tensor read and combined by OutputOp

    typename OutputOp::Params output_op;    ///< Parameters for the output functor
    typename ConvertOp::Params convert_op;  ///< Parameters for the conversion functor

    //
    // Methods
    //

    /// Constructs a Params object
    CUTLASS_HOST_DEVICE
    Params(
      TensorRef destination_ref_,
      TensorRef source_ref_,
      typename OutputOp::Params output_op_,
      typename ConvertOp::Params convert_op_
    ):
      destination_ref(destination_ref_),
      source_ref(source_ref_),
      output_op(output_op_),
      convert_op(convert_op_) {

    }

    /// Constructs a Params object with a default-constructed ConvertOp
    CUTLASS_HOST_DEVICE
    Params(
      TensorRef destination_ref_,
      TensorRef source_ref_,
      typename OutputOp::Params output_op_
    ):
      Params(
        destination_ref_,   // fix: forward the constructor parameters; the original
        source_ref_,        //      delegated with the member names, which are not yet
        output_op_,         //      initialized at this point
        ConvertOp::Params()
      ) { }
  };

  /// Shared storage allocation needed by the epilogue (none required)
  struct SharedStorage { };

private:

  OutputOp output_op;       ///< Output functor (constructed from Params)
  ConvertOp convert_op;     ///< Conversion functor (constructed from Params)

  TensorRef destination_ref_;   ///< Destination tensor, offset to this warp's origin
  TensorRef source_ref_;        ///< Source tensor, offset to this warp's origin

  MatrixCoord warp_origin_;     ///< This warp's (row, column) origin within the threadblock tile

public:

  /// Constructor
  CUTLASS_DEVICE
  DirectEpilogueTensorOp(
    Params const &params,
    SharedStorage &shared_storage,
    int thread_idx,
    int warp_idx,
    int lane_idx
  ):
    output_op(params.output_op),
    convert_op(params.convert_op),
    destination_ref_(params.destination_ref),
    source_ref_(params.source_ref) {

    // NOTE(review): lane_idx and thread_idx are unused here — per-lane addressing is
    // not applied to the tensor refs. Confirm against the original file whether a
    // lane-level offset is expected.

    // Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
    //
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
    int warp_m = warp_mn % WarpCount::kM;
    int warp_n = warp_mn / WarpCount::kM;

    warp_origin_ = MatrixCoord{
      warp_m * Operator::Shape::kM,
      warp_n * Operator::Shape::kN
    };

    // Offset both tensor references so subsequent accesses are warp-relative.
    destination_ref_.add_coord_offset(warp_origin_);
    source_ref_.add_coord_offset(warp_origin_);
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    gemm::GemmCoord problem_size,     ///< Problem size of the GEMM (used for bounds checking)
    gemm::GemmCoord tb_tile_coord,    ///< Coordinate of this threadblock's tile within the grid
    FragmentC const &accumulators) {  ///< Warp-scoped accumulator tile

    // Absolute origin of this thread's warp tile within the full problem.
    MatrixCoord thread_origin =
      MatrixCoord{tb_tile_coord.m() * Shape::kM, tb_tile_coord.n() * Shape::kN} + warp_origin_;

    /// Number of MMA operations tiling the warp-scoped tile
    using MmaIterations = MatrixShape<
      Operator::Shape::kM / Operator::Policy::Operator::Shape::kM,
      Operator::Shape::kN / Operator::Policy::Operator::Shape::kN
    >;

    // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire
    // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements
    // of that row. The accumulators within one row are assumed to be consecutive.
    int const kElementsPerAccess = Operator::Policy::Operator::Shape::kN / 4;
    int const kRowsPerTile = 8;
    int const kAccumulatorRows = Operator::Policy::Operator::Shape::kM / kRowsPerTile;

    CUTLASS_PRAGMA_UNROLL
    for (int mma_n = 0; mma_n < MmaIterations::kN; ++mma_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int mma_m = 0; mma_m < MmaIterations::kM; ++mma_m) {

        // Index of the first accumulator belonging to this (mma_m, mma_n) MMA tile.
        int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
          (mma_m * MmaIterations::kN + mma_n);

        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < kAccumulatorRows; ++row) {
          CUTLASS_PRAGMA_UNROLL
          for (int col = 0; col < kElementsPerAccess; ++col) {

            int accum_m = mma_m * Operator::Policy::Operator::Shape::kM + row * kRowsPerTile;
            int accum_n = mma_n * Operator::Policy::Operator::Shape::kN + col;
            int idx = mma_accum_start + row * kElementsPerAccess + col;

            MatrixCoord accum_coord = MatrixCoord{accum_m, accum_n};

            MatrixCoord thread_coord = thread_origin + accum_coord;

            // Guard against partial tiles at the problem boundary.
            if (thread_coord < MatrixCoord{problem_size.m(), problem_size.n()}) {

              typename ConvertOp::result_type converted_accum = convert_op(accumulators[idx]);

              typename OutputOp::result_type output = output_op(converted_accum, source_ref_.at(accum_coord));

              destination_ref_.at(accum_coord) = output;
            }
          }
        }
      }
    }
  }
};
245 
247 
248 } // namespace threadblock
249 } // namespace epilogue
250 } // namespace cutlass
251 
Epilogue operator.
Definition: direct_epilogue_tensor_op.h:55
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_)
Constructs a Params object.
Definition: direct_epilogue_tensor_op.h:125
Parameters structure for host-constructible state.
Definition: direct_epilogue_tensor_op.h:92
Definition: aligned_buffer.h:35
CUTLASS_DEVICE void operator()(gemm::GemmCoord problem_size, gemm::GemmCoord tb_tile_coord, FragmentC const &accumulators)
Streams the result to global memory.
Definition: direct_epilogue_tensor_op.h:189
TensorRef destination_ref
Definition: direct_epilogue_tensor_op.h:98
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
TensorRef< Element, Layout::kRank, Layout > TensorRef
Reference to source and destination tensors.
Definition: direct_epilogue_tensor_op.h:87
CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)
Adds an offset to each pointer.
Definition: tensor_ref.h:326
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
TensorRef source_ref
Definition: direct_epilogue_tensor_op.h:99
CUTLASS_DEVICE DirectEpilogueTensorOp(Params const &params, SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: direct_epilogue_tensor_op.h:155
OutputOp::Params output_op
Definition: direct_epilogue_tensor_op.h:101
ConvertOp::Params convert_op
Definition: direct_epilogue_tensor_op.h:102
CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_, typename ConvertOp::Params convert_op_)
Constructs a Params object.
Definition: direct_epilogue_tensor_op.h:110
Shared storage allocation needed by the epilogue.
Definition: direct_epilogue_tensor_op.h:139
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
typename Operator::FragmentC FragmentC
Accumulator tile is really the warp-scoped tile.
Definition: direct_epilogue_tensor_op.h:72
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Operator_ Operator
Definition: direct_epilogue_tensor_op.h:59
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
OutputOp_ OutputOp
Function operator computing final output.
Definition: direct_epilogue_tensor_op.h:81
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
ConvertOp_ ConvertOp
Conversion operator to shared memory.
Definition: direct_epilogue_tensor_op.h:84
Basic include for CUTLASS.
Definition: matrix_coord.h:39
Element_ Element
Data type of output tensor.
Definition: direct_epilogue_tensor_op.h:75
Shape_ Shape
Definition: direct_epilogue_tensor_op.h:58
static int const kN
Definition: include/cutlass/gemm/gemm.h:59