CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tools/util/include/cutlass/util/reference/device/thread/gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/coord.h"
32 #include "cutlass/matrix_traits.h"
33 #include "cutlass/tensor_view.h"
34 #include "cutlass/gemm/gemm.h"
35 
36 namespace cutlass {
37 namespace reference {
38 namespace device {
39 namespace thread {
40 
42 
44 //
45 // Note, this is a reference implementation. Performance is not expected to approach peak.
46 //
47 template <
48  typename TensorRefA,
49  typename TensorRefB,
50  typename TensorRefC,
51  typename ScalarType,
52  typename AccumulatorType,
53  typename OutputTile,
54  typename InnerProductOp = multiply_add<AccumulatorType>,
56 >
57 struct Gemm {
58 
59  using ElementA = typename TensorRefA::Element;
60  using ElementB = typename TensorRefB::Element;
61  using ElementC = typename TensorRefC::Element;
62 
63  //
64  // Data members
65  //
66 
68  ElementA A_tile[OutputTile::kColumn];
69 
71  ElementB B_tile[OutputTile::kRow];
72 
74  AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow];
75 
76  //
77  // Methods
78  //
79 
82  Gemm(AccumulatorType initial_accum = AccumulatorType(0)) {
83 
84  // Clear fetch registers
85  for (int i = 0; i < OutputTile::kColumn; ++i) {
86  A_tile[i] = ElementA(0);
87  }
88 
89  for (int j = 0; j < OutputTile::kColumn; ++j) {
90  B_tile[j] = ElementB(0);
91  }
92 
93  // Clear accumulators
95  for (int j = 0; j < OutputTile::kColumn; ++j) {
97  for (int i = 0; i < OutputTile::kRow; ++i) {
98  accum[j][i] = initial_accum;
99  }
100  }
101  }
102 
106  gemm::GemmCoord problem_size,
107  TensorRefA tensor_a,
108  TensorRefB tensor_b,
109  MatrixCoord output_coord = MatrixCoord()) {
110 
111  InnerProductOp inner_product_op;
112 
113  // Loop over the GEMM K dimension
115  for (int k = 0; k < problem_size.k(); ++k) {
116 
117  // Fetch a slice of the A matrix
119  for (int i = 0; i < OutputTile::kColumn; ++i) {
120  if (output_coord.row() + i < problem_size.m()) {
121  A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k));
122  }
123  }
124 
125  // Fetch a slice of the B matrix
127  for (int j = 0; j < OutputTile::kRow; ++j) {
128  if (output_coord.column() + j < problem_size.n()) {
129  B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j));
130  }
131  }
132 
133  // Compute an accumulated matrix product
135  for (int j = 0; j < OutputTile::kRow; ++j) {
137  for (int i = 0; i < OutputTile::kColumn; ++i) {
138  accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]);
139  }
140  }
141  }
142 
143  return *this;
144  }
145 
149  gemm::GemmCoord problem_size,
150  ScalarType alpha,
151  ScalarType beta,
152  TensorRefC tensor_c,
153  TensorRefC tensor_d,
154  MatrixCoord output_coord = MatrixCoord()) {
155 
156  ConvertOp convert_op;
157 
158  // Update the output tensor
159  for (int j = 0; j < OutputTile::kRow; ++j) {
160  for (int i = 0; i < OutputTile::kColumn; ++i) {
161  MatrixCoord coord = output_coord + MatrixCoord(i, j);
162  if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
163 
164  tensor_d.at(coord) = convert_op(
165  alpha * ScalarType(accum[j][i]) +
166  beta * ScalarType(tensor_c.at(coord))
167  );
168  }
169  }
170  }
171 
172  return *this;
173  }
174 };
175 
177 
178 } // namespace thread
179 } // namespace device
180 } // namespace reference
181 } // namespace cutlass
Fused multiply-add.
Definition: functional.h:92
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
Thread-level blocked general matrix product.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:57
Definition: aligned_buffer.h:35
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
typename TensorRefA::Element ElementA
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:59
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
ElementB B_tile[OutputTile::kRow]
Tile for B operand.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:71
Definition: include/cutlass/gemm/gemm.h:94
AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]
Tile for Accumulator.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:74
typename TensorRefB::Element ElementB
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:60
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data.
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
CUTLASS_HOST_DEVICE Gemm(AccumulatorType initial_accum=AccumulatorType(0))
Constructor.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:82
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
ElementA A_tile[OutputTile::kColumn]
Tile for A operand.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:68
typename TensorRefC::Element ElementC
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:61
CUTLASS_HOST_DEVICE Gemm & multiply_add(gemm::GemmCoord problem_size, TensorRefA tensor_a, TensorRefB tensor_b, MatrixCoord output_coord=MatrixCoord())
Computes a matrix product.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:105
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
#define CUTLASS_PRAGMA_NO_UNROLL
Definition: cutlass.h:111
CUTLASS_HOST_DEVICE Gemm & epilogue(gemm::GemmCoord problem_size, ScalarType alpha, ScalarType beta, TensorRefC tensor_c, TensorRefC tensor_d, MatrixCoord output_coord=MatrixCoord())
Performs linear scaling of matrix product and updates output tensor.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:148
Definition: numeric_conversion.h:59
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Definition: matrix_coord.h:39