CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm/thread/mma_sm50.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/tensor_ref.h"
33 #include "cutlass/layout/matrix.h"
34 #include "cutlass/arch/mma.h"
35 #include "cutlass/gemm/gemm.h"
37 
39 
40 namespace cutlass {
41 namespace gemm {
42 namespace thread {
43 
45 
/// Template that handles all packed matrix layouts.
///
/// Computes D = A * B + C on per-thread fragments by walking every (k, n, m)
/// coordinate and issuing one scalar multiply-add (MmaOp) per coordinate.
///
/// NOTE(review): this listing is lossy. The original per-member doc comments were
/// stripped and several numbered lines are absent (102, 107, 113-115, 121, 124,
/// 127, 136, 139, 142). Confirm any detail against the real header before relying on it.
47 template <
49  typename Shape_,
51  typename ElementA_,
53  typename LayoutA_,
55  typename ElementB_,
57  typename LayoutB_,
59  typename ElementC_,
61  typename LayoutC_,
63  typename Operator_
64 >
65 struct MmaGeneric {
66 
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
68  using Shape = Shape_;
69 
    /// Data type of operand A.
71  using ElementA = ElementA_;
72 
    /// Layout of A matrix (concept: layout::MapFunc)
74  using LayoutA = LayoutA_;
75 
    /// Data type of operand B.
77  using ElementB = ElementB_;
78 
    /// Layout of B matrix (concept: layout::MapFunc)
80  using LayoutB = LayoutB_;
81 
    /// Element type of operand C.
83  using ElementC = ElementC_;
84 
    /// Layout of C matrix (concept: layout::MapFunc)
86  using LayoutC = LayoutC_;
87 
    /// Underlying mathematical operator.
89  using Operator = Operator_;
90 
    /// A operand storage.
92  using FragmentA = Array<ElementA, Shape::kMK>;
93 
    /// B operand storage.
95  using FragmentB = Array<ElementB, Shape::kKN>;
96 
    /// C operand storage.
98  using FragmentC = Array<ElementC, Shape::kMN>;
99 
    /// Multiply-add operation applied to one (a, b, d) element triple per loop iteration.
    /// NOTE(review): listing lines 102 (first template argument) and 107 (presumably the
    /// closing `Operator>;`) are absent, so this alias is unterminated as shown - confirm.
101  using MmaOp = arch::Mma<
103  1,
104  ElementA, LayoutA,
105  ElementB, LayoutB,
106  ElementC, LayoutC,
108 
109  //
110  // Methods
111  //
112 
    /// Computes a matrix product D = A * B + C.
    /// NOTE(review): the `CUTLASS_HOST_DEVICE void operator()(` line(s) that open this
    /// parameter list (listing lines 113-115) are missing from this view.
116  FragmentC & D,
117  FragmentA const & A,
118  FragmentB const & B,
119  FragmentC const & C) {
120 
    // The three statements below view the fragments A, B and D as packed matrices of
    // extent (kM x kK), (kK x kN) and (kM x kN). NOTE(review): the declarations that
    // name these views (used later as a_ref, b_ref, d_ref) sit on the missing listing
    // lines 121, 124 and 127.
122  reinterpret_cast<ElementA const *>(&A), LayoutA::packed({Shape::kM, Shape::kK}));
123 
125  reinterpret_cast<ElementB const *>(&B), LayoutB::packed({Shape::kK, Shape::kN}));
126 
128  reinterpret_cast<ElementC *>(&D), LayoutC::packed({ Shape::kM, Shape::kN }));
129 
130  MmaOp mma_op;
131 
132  // Copy accumulators
133  D = C;
134 
135  // Compute matrix product
    // Loop nest is k (outer), n, m (inner); each iteration updates one element of D.
    // NOTE(review): listing lines 136, 139 and 142 (one before each `for`) are missing;
    // presumably CUTLASS_PRAGMA_UNROLL - confirm.
137  for (int k = 0; k < Shape::kK; ++k) {
138 
140  for (int n = 0; n < Shape::kN; ++n) {
141 
143  for (int m = 0; m < Shape::kM; ++m) {
144 
    // Odd columns walk m in reverse, so the row index does not jump when n advances
    // (the last row of one column is the first row of the next).
145  int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m;
146 
147  MatrixCoord mn(m_serpentine, n);
148  MatrixCoord mk(m_serpentine, k);
149  MatrixCoord kn(k, n);
150 
    // Single-element operands for the scalar multiply-add.
151  Array<ElementC, 1> d;
152  Array<ElementA, 1> a;
153  Array<ElementB, 1> b;
154 
155  d[0] = d_ref.at(mn);
156  a[0] = a_ref.at(mk);
157  b[0] = b_ref.at(kn);
158 
    // d is passed as both destination and accumulator (presumably d = a * b + d).
159  mma_op(d, a, b, d);
160 
    // Write the updated accumulator back into D.
161  d_ref.at(mn) = d[0];
162  }
163  }
164  }
165  }
166 };
167 
168 
170 
/// Partial specialization of Mma for the arch::OpMultiplyAdd operator.
///
/// operator() forwards to MmaGeneric instantiated with the same shape, element
/// and layout types, so the computation is D = A * B + C.
///
/// NOTE(review): the primary Mma template is declared elsewhere (gemm/thread/mma.h
/// per the page's cross-references); the meaning of the trailing `bool` argument
/// depends on that declaration - confirm there.
172 template <
174  typename Shape_,
176  typename ElementA_,
178  typename LayoutA_,
180  typename ElementB_,
182  typename LayoutB_,
184  typename ElementC_,
186  typename LayoutC_
187 >
188 struct Mma<
189  Shape_,
190  ElementA_,
191  LayoutA_,
192  ElementB_,
193  LayoutB_,
194  ElementC_,
195  LayoutC_,
196  arch::OpMultiplyAdd,
197  bool> {
198 
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
200  using Shape = Shape_;
201 
    /// Data type of operand A.
203  using ElementA = ElementA_;
204 
    /// Layout of A matrix (concept: layout::MapFunc)
206  using LayoutA = LayoutA_;
207 
    /// Data type of operand B.
209  using ElementB = ElementB_;
210 
    /// Layout of B matrix (concept: layout::MapFunc)
212  using LayoutB = LayoutB_;
213 
    /// Element type of operand C.
215  using ElementC = ElementC_;
216 
    /// Layout of C matrix (concept: layout::MapFunc)
218  using LayoutC = LayoutC_;
219 
    /// Underlying mathematical operator.
221  using Operator = arch::OpMultiplyAdd;
222 
    /// A operand storage.
224  using FragmentA = Array<ElementA, Shape::kMK>;
225 
    /// B operand storage.
227  using FragmentB = Array<ElementB, Shape::kKN>;
228 
    /// C operand storage.
230  using FragmentC = Array<ElementC, Shape::kMN>;
231 
232  //
233  // Methods
234  //
235 
    /// Computes a matrix product D = A * B + C.
    /// NOTE(review): the `CUTLASS_HOST_DEVICE void operator()(` line(s) that open this
    /// parameter list (listing lines 236-238) are missing from this view.
239  FragmentC & D,
240  FragmentA const & A,
241  FragmentB const & B,
242  FragmentC const & C) {
243 
    // Delegate to the generic packed-layout implementation with identical type arguments.
244  MmaGeneric<
245  Shape,
246  ElementA,
247  LayoutA,
248  ElementB,
249  LayoutB,
250  ElementC,
251  LayoutC,
252  Operator> mma;
253 
254  mma(D, A, B, C);
255  }
256 };
257 
259 
260 } // namespace thread
261 } // namespace gemm
262 } // namespace cutlass
263 
Operator_ Operator
Underlying mathematical operator.
Definition: gemm/thread/mma_sm50.h:89
Definition: aligned_buffer.h:35
Array< ElementB, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm50.h:95
Defines a structure containing strides, bounds, and a pointer to tensor data.
Array< ElementC, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm50.h:98
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm50.h:238
LayoutA_ LayoutA
Layout of A matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:74
Array< ElementC, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm50.h:230
LayoutC_ LayoutC
Layout of C matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:86
Defines common types used for all GEMM-like operators.
ElementA_ ElementA
Data type of operand A.
Definition: gemm/thread/mma_sm50.h:71
LayoutB_ LayoutB
Layout of B matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:80
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for multiply-add operations.
LayoutB_ LayoutB
Layout of B matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:212
Template that handles all packed matrix layouts.
Definition: gemm/thread/mma_sm50.h:65
Array< ElementA, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm50.h:224
Array< ElementA, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm50.h:92
arch::OpMultiplyAdd Operator
Underlying mathematical operator.
Definition: gemm/thread/mma_sm50.h:221
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: gemm/thread/mma_sm50.h:68
Templates exposing architecture support for warp-level multiply-add operations.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Array< ElementB, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm50.h:227
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: gemm/thread/mma_sm50.h:200
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm50.h:115
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
Structure to compute the matrix product.
Definition: gemm/thread/mma.h:66
Defines layout functions used by TensorRef and derived classes.
LayoutA_ LayoutA
Layout of A matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:206
Matrix multiply-add operation.
Definition: arch/mma.h:92
LayoutC_ LayoutC
Layout of C matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:218
Basic include for CUTLASS.
Definition: matrix_coord.h:39
ElementB_ ElementB
Data type of operand B.
Definition: gemm/thread/mma_sm50.h:77
ElementC_ ElementC
Element type of operand C.
Definition: gemm/thread/mma_sm50.h:83
Matrix multiply-add operation - specialized for 1x1x1x1 matrix multiply operation.
Definition: arch/mma.h:113