CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tools/util/include/cutlass/util/reference/device/gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
#include "cutlass/coord.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/functional.h"

#include "cutlass/matrix_traits.h"
#include "cutlass/tensor_view.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/util/reference/device/kernel/gemm.h"
40 
42 
43 namespace cutlass {
44 namespace reference {
45 namespace device {
46 
48 
56 template <
57  typename ElementA,
58  typename LayoutA,
59  typename ElementB,
60  typename LayoutB,
61  typename ElementC,
62  typename LayoutC,
63  typename ScalarType,
64  typename AccumulatorType,
65  typename InnerProductOp = multiply_add<AccumulatorType>,
66  typename ConvertOp = NumericConverter<ElementC, ScalarType>
67 >
69  gemm::GemmCoord problem_size,
70  ScalarType alpha,
73  ScalarType beta,
76  AccumulatorType initial_accum) {
77 
79  LayoutA::kRank == 2 &&
80  LayoutB::kRank == 2 &&
81  LayoutC::kRank == 2, "Tensors must be of rank 2");
82 
83  // Blocking structure potentially improves performance of reference implementation
84  // with a minor increase in complexity.
85  //
86  // Note, this reference implementation is NOT expected to approach peak performance.
87  using OutputTile = MatrixShape<4, 4>;
88 
89  dim3 block(16, 8);
90 
91  dim3 grid(
92  (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
93  (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
94  );
95 
96  // Launch a GEMM kernel
101  ScalarType,
102  AccumulatorType,
103  OutputTile,
104  InnerProductOp,
105  ConvertOp
106  ><<< grid, block >>>(
107  problem_size,
108  alpha,
109  tensor_a,
110  tensor_b,
111  beta,
112  tensor_c,
113  tensor_d,
114  initial_accum
115  );
116 }
118 
123 template <
124  typename ElementA,
125  typename LayoutA,
126  typename ElementB,
127  typename LayoutB,
128  typename ElementC,
129  typename LayoutC,
130  typename ScalarType,
131  typename AccumulatorType,
132  typename InnerProductOp = multiply_add<AccumulatorType>,
133  typename ConvertOp = NumericConverter<ElementC, ScalarType>
134 >
136  gemm::GemmCoord problem_size,
137  ScalarType alpha,
140  ScalarType beta,
142  AccumulatorType initial_accum) {
143 
144  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
145  ScalarType, AccumulatorType, InnerProductOp, ConvertOp>(
146  problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
147  initial_accum);
148 }
149 
150 template <
151  typename ElementA,
152  typename LayoutA,
153  typename ElementB,
154  typename LayoutB,
155  typename ElementC,
156  typename LayoutC,
157  typename ScalarType,
158  typename AccumulatorType,
159  typename InnerProductOp = cutlass::arch::OpMultiplyAdd
160 >
161 struct Gemm;
162 
164 
166 template <typename ElementA, typename LayoutA, typename ElementB,
167  typename LayoutB, typename ElementC, typename LayoutC,
168  typename ScalarType, typename AccumulatorType>
169 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
170  ScalarType, AccumulatorType, arch::OpMultiplyAdd> {
171 
172  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
174  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
176  AccumulatorType initial_accum = AccumulatorType(0)) {
177 
179  LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
180  "Tensors must be of rank 2");
181 
182  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
183  ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
184  problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
185  }
186 
187  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
189  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
192  AccumulatorType initial_accum = AccumulatorType(0)) {
194  LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
195  "Tensors must be of rank 2");
196 
197  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
198  ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
199  problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
200  }
201 };
202 
204 
206 template <typename ElementA, typename LayoutA, typename ElementB,
207  typename LayoutB, typename ElementC, typename LayoutC,
208  typename ScalarType, typename AccumulatorType>
209 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
210  AccumulatorType, arch::OpMultiplyAddSaturate> {
211 
212  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
214  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
216  AccumulatorType initial_accum = AccumulatorType(0)) {
218  LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
219  "Tensors must be of rank 2");
220 
221  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
222  ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
224  problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
225  }
226 
227  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
229  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
232  AccumulatorType initial_accum = AccumulatorType(0)) {
234  LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
235  "Tensors must be of rank 2");
236 
237  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
238  ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
240  problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
241  }
242 };
243 
245 
247 template <typename ElementA, typename LayoutA, typename ElementB,
248  typename LayoutB, typename ElementC, typename LayoutC,
249  typename ScalarType, typename AccumulatorType>
250 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
251  AccumulatorType, arch::OpXorPopc> {
252 
253  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
255  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
257  AccumulatorType initial_accum = AccumulatorType(0)) {
259  LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
260  "Tensors must be of rank 2");
261 
262  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
263  ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
264  problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
265  }
266 
267  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
269  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
272  AccumulatorType initial_accum = AccumulatorType(0)) {
274  LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
275  "Tensors must be of rank 2");
276 
277  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
278  ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
279  problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
280  }
281 };
282 
283 
285 //
286 // Batched GEMM
287 //
289 
291 //
292 // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
293 //
294 template <
295  typename TensorRefCollectionA,
296  typename TensorRefCollectionB,
297  typename TensorRefCollectionC,
298  typename ScalarType,
299  typename AccumulatorType,
300  typename InnerProductOp,
301  typename ConvertOp
302 >
304  gemm::GemmCoord problem_size,
305  int batch_count,
306  ScalarType alpha,
307  TensorRefCollectionA const& tensor_a,
308  TensorRefCollectionB const& tensor_b,
309  ScalarType beta,
310  TensorRefCollectionC &tensor_c,
311  AccumulatorType initial_accum) {
312 
314  TensorRefCollectionA::kRank == 2 &&
315  TensorRefCollectionB::kRank == 2 &&
316  TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2");
317 
318  // Blocking structure potentially improves performance of reference implementation
319  // with a minor increase in complexity.
320  //
321  // Note, this reference implementation is NOT expected to approach peak performance.
322  using OutputTile = MatrixShape<4, 4>;
323 
324  dim3 block(16, 8);
325  dim3 grid(
326  (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
327  (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn),
328  batch_count
329  );
330 
331  // Launch a GEMM kernel
333  TensorRefCollectionA,
334  TensorRefCollectionB,
335  TensorRefCollectionC,
336  ScalarType,
337  AccumulatorType,
338  OutputTile,
339  InnerProductOp,
340  ConvertOp
341  ><<< grid, block >>>(
342  problem_size,
343  alpha,
344  tensor_a,
345  tensor_b,
346  beta,
347  tensor_c,
348  initial_accum
349  );
350 }
351 
354 //
355 // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
356 //
357 template <
358  typename TensorRefCollectionA,
359  typename TensorRefCollectionB,
360  typename TensorRefCollectionC,
361  typename ScalarType,
362  typename AccumulatorType
363 >
365  gemm::GemmCoord problem_size,
366  int batch_count,
367  ScalarType alpha,
368  TensorRefCollectionA const& tensor_a,
369  TensorRefCollectionB const& tensor_b,
370  ScalarType beta,
371  TensorRefCollectionC &tensor_c) {
372 
373  BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
374 }
375 
377 
378 } // namespace device
379 } // namespace reference
380 } // namespace cutlass
Fused multiply-add.
Definition: functional.h:92
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:267
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Definition: numeric_conversion.h:254
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data.
__global__ void BatchedGemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRefCollectionA tensor_collection_a, TensorRefCollectionB tensor_collection_b, ScalarType beta, TensorRefCollectionC tensor_collection_c, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/kernel/gemm.h:108
Boost-like numeric conversion operator for CUTLASS numeric types.
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:161
void BatchedGemm(gemm::GemmCoord problem_size, int batch_count, ScalarType alpha, TensorRefCollectionA const &tensor_a, TensorRefCollectionB const &tensor_b, ScalarType beta, TensorRefCollectionC &tensor_c, AccumulatorType initial_accum)
Computes a batch of GEMMs over a set of matrices of common dimension.
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:303
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:212
Top-level include for all CUTLASS numeric types.
#define static_assert(__e, __m)
Definition: platform.h:153
Definition: numeric_conversion.h:59
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:227
Fused multiply-add.
Definition: functional.h:101
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
void compute_gemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:68
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:187
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:253
Defines properties of matrices used to denote layout and operands to GEMM kernels.
__global__ void Gemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRefA tensor_a, TensorRefB tensor_b, ScalarType beta, TensorRefC tensor_c, TensorRefC tensor_d, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/kernel/gemm.h:57
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Reference implementation for GEMM in host-side code.
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:172