CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tools/util/include/cutlass/util/reference/host/gemm.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Reference implementation for GEMM in host-side code.
*/
#pragma once

#include "cutlass/coord.h"
#include "cutlass/numeric_types.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/matrix_traits.h"
#include "cutlass/tensor_view.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/arch/mma.h"

namespace cutlass {
namespace reference {
namespace host {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = multiply_add<ComputeType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_gemm(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRef<ElementA, LayoutA> tensor_a,
  TensorRef<ElementB, LayoutB> tensor_b,
  ScalarType beta,
  TensorRef<ElementC, LayoutC> tensor_c,
  TensorRef<ElementC, LayoutC> tensor_d,
  ComputeType initial_accum) {

  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");

  // Note: batch is ignored.
  int const M = problem_size.m();
  int const N = problem_size.n();
  int const K = problem_size.k();

  // Blocking is necessary to speed up the reference implementation.
  int const Mblock = 16;
  int const Nblock = 16;

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  for (int row_block = 0; row_block < M; row_block += Mblock) {
    for (int col_block = 0; col_block < N; col_block += Nblock) {

      ComputeType accum[Mblock][Nblock];

      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          accum[i][j] = initial_accum;
        }
      }

      for (int k_block = 0; k_block < K; ++k_block) {
        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Mblock; i++) {
            int row = row_block + i;
            int col = col_block + j;

            if (row < M && col < N) {
              ElementA a = tensor_a.at(MatrixCoord(row, k_block));
              ElementB b = tensor_b.at(MatrixCoord(k_block, col));

              accum[i][j] = inner_product_op(ComputeType(a), ComputeType(b), accum[i][j]);
            }
          }
        }
      }

      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          int row = row_block + i;
          int col = col_block + j;

          MatrixCoord coord = MatrixCoord(row, col);

          if (row < M && col < N) {
            tensor_d.at(coord) = convert_op(
              alpha * ScalarType(accum[i][j]) +
              beta * ScalarType(tensor_c.at(coord)));
          }
        }
      }
    }
  }
}
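The InnerProductOp parameter may be any ternary functor called as op(a, b, accum) once per K index; multiply_add from cutlass/functional.h is the default. A minimal sketch of that contract (the demo function and its values are illustrative, not part of this file):

#include "cutlass/functional.h"

float inner_product_contract_demo() {
  // multiply_add<T> computes a * b + c -- the default InnerProductOp.
  cutlass::multiply_add<float> mad;
  float acc = 0.0f;
  acc = mad(2.0f, 3.0f, acc);   // 2 * 3 + 0 = 6
  acc = mad(4.0f, 0.5f, acc);   // 4 * 0.5 + 6 = 8
  return acc;
}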

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects, writing the result back into tensor_c.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = multiply_add<ComputeType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_gemm(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRef<ElementA, LayoutA> tensor_a,
  TensorRef<ElementB, LayoutB> tensor_b,
  ScalarType beta,
  TensorRef<ElementC, LayoutC> tensor_c,
  ComputeType initial_accum) {
  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
               ScalarType, ComputeType, InnerProductOp, ConvertOp>(
      problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
      initial_accum);
}
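Taken together, the two overloads compute D = alpha * (A B) + beta * C, with this second one overwriting tensor_c in place. A minimal usage sketch, assuming the util library's HostTensor helper and single-precision, column-major operands (names and sizes are illustrative):

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"

void reference_gemm_example() {
  cutlass::gemm::GemmCoord problem_size(128, 96, 64);   // M, N, K

  // A is M-by-K, B is K-by-N, C is M-by-N.
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> A(problem_size.mk());
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B(problem_size.kn());
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> C(problem_size.mn());

  // ... fill A, B, and C ...

  // In-place overload: C <- 1.0 * A B + 0.0 * C, accumulating from zero.
  cutlass::reference::host::compute_gemm<
      float, cutlass::layout::ColumnMajor,    // ElementA, LayoutA
      float, cutlass::layout::ColumnMajor,    // ElementB, LayoutB
      float, cutlass::layout::ColumnMajor,    // ElementC, LayoutC
      float, float>(                          // ScalarType, ComputeType
      problem_size, 1.0f, A.host_ref(), B.host_ref(), 0.0f, C.host_ref(), 0.0f);
}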

////////////////////////////////////////////////////////////////////////////////////////////////////

/// General GEMM functor; the specializations below select the inner-product and conversion
/// operations for each operator class.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = cutlass::arch::OpMultiplyAdd
>
struct Gemm;

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for multiply-add
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
            ComputeType, arch::OpMultiplyAdd> {

  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, multiply_add<ComputeType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }

  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, multiply_add<ComputeType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d,
        initial_accum);
  }
};
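The functor form is a convenience wrapper over compute_gemm: the operator-class tag selects the inner-product and conversion operations so callers need not spell them out. A sketch of its use, with types chosen purely for illustration:

void functor_example(cutlass::gemm::GemmCoord problem_size,
                     cutlass::TensorRef<float, cutlass::layout::RowMajor> A,
                     cutlass::TensorRef<float, cutlass::layout::ColumnMajor> B,
                     cutlass::TensorRef<float, cutlass::layout::RowMajor> C) {
  // InnerProductOp defaults to arch::OpMultiplyAdd, selecting the specialization above.
  cutlass::reference::host::Gemm<
      float, cutlass::layout::RowMajor,     // ElementA, LayoutA
      float, cutlass::layout::ColumnMajor,  // ElementB, LayoutB
      float, cutlass::layout::RowMajor,     // ElementC, LayoutC
      float, float> reference_gemm;         // ScalarType, ComputeType

  reference_gemm(problem_size, 1.0f, A, B, 0.0f, C);  // C <- A B
}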

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for multiply-add-saturate
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
            ComputeType, arch::OpMultiplyAddSaturate> {

  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, multiply_add<ComputeType>,
                 NumericConverterClamp<ElementC, ScalarType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }

  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, multiply_add<ComputeType>,
                 NumericConverterClamp<ElementC, ScalarType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d,
        initial_accum);
  }
};
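Relative to the multiply-add specialization above, only the ConvertOp changes: NumericConverterClamp saturates values that fall outside the destination type's range rather than letting them wrap. A small sketch of that behavior (the int8_t/float pairing is an assumption for the example):

#include <cstdint>
#include "cutlass/numeric_conversion.h"

std::int8_t clamp_demo() {
  // Converts float to int8_t, clamping to the destination range [-128, 127].
  cutlass::NumericConverterClamp<std::int8_t, float> convert;
  return convert(300.0f);  // yields 127 rather than a wrapped-around value
}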

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for XOR-popc
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
            ComputeType, arch::OpXorPopc> {

  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, xor_add<ComputeType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }

  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, xor_add<ComputeType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d,
        initial_accum);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Batched GEMM
//
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Computes a batch of GEMMs over a set of matrices of common dimension.
//
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
//
template <
  typename TensorRefCollectionA,
  typename TensorRefCollectionB,
  typename TensorRefCollectionC,
  typename ScalarType,
  typename AccumulatorType
>
void BatchedGemm(
  gemm::GemmCoord problem_size,
  int batch_count,
  ScalarType alpha,
  TensorRefCollectionA const& tensor_a,
  TensorRefCollectionB const& tensor_b,
  ScalarType beta,
  TensorRefCollectionC &tensor_c,
  AccumulatorType initial_accum) {

  typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin();
  typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin();
  typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin();

  for (int batch = 0;
       batch < batch_count;
       ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) {

    Gemm<typename TensorRefCollectionA::Element,
         typename TensorRefCollectionA::Layout,
         typename TensorRefCollectionB::Element,
         typename TensorRefCollectionB::Layout,
         typename TensorRefCollectionC::Element,
         typename TensorRefCollectionC::Layout,
         typename TensorRefCollectionC::Element,
         typename TensorRefCollectionC::Element>
        gemm;

    gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it,
         initial_accum);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Computes a batch of GEMMs over a set of matrices of common dimension.
//
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
//
template <
  typename TensorRefCollectionA,
  typename TensorRefCollectionB,
  typename TensorRefCollectionC,
  typename ScalarType,
  typename AccumulatorType
>
void BatchedGemm(
  gemm::GemmCoord problem_size,
  int batch_count,
  ScalarType alpha,
  TensorRefCollectionA const& tensor_a,
  TensorRefCollectionB const& tensor_b,
  ScalarType beta,
  TensorRefCollectionC &tensor_c) {

  BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
}
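The TensorRefCollection concept is only described informally above: a collection must expose Element and Layout typedefs, a ConstIterator whose dereference yields a TensorRef, and begin(). A minimal hypothetical type satisfying it (TensorRefVector is our name for this sketch, not a CUTLASS type):

#include <vector>
#include "cutlass/tensor_ref.h"

template <typename Element_, typename Layout_>
struct TensorRefVector {
  using Element = Element_;
  using Layout = Layout_;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using ConstIterator = typename std::vector<TensorRef>::const_iterator;

  std::vector<TensorRef> refs;   // one TensorRef per batch entry

  ConstIterator begin() const { return refs.begin(); }
};

// Usage sketch (illustrative): populate one TensorRef per batch in each
// collection, then invoke the nine-argument overload so AccumulatorType is
// deducible from the trailing initial accumulator:
//
//   TensorRefVector<float, cutlass::layout::ColumnMajor> A, B, C;
//   // ... push batch_count TensorRefs into A.refs, B.refs, C.refs ...
//   cutlass::reference::host::BatchedGemm(
//       problem_size, batch_count, 1.0f, A, B, 0.0f, C, 0.0f);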

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace host
} // namespace reference
} // namespace cutlass