CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
include/cutlass/gemm/gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/cutlass.h"
31 #include "cutlass/coord.h"
32 
33 namespace cutlass {
34 namespace gemm {
35 
37 
/// GEMM operand enumeration: D = A * B + C
///
/// Enumerator values are spelled out explicitly; they match the implicit
/// 0..3 numbering, so any code that casts an Operand to an integer is unaffected.
enum class Operand {
  kA = 0,   ///< A multiplicand (left-hand factor of A * B)
  kB = 1,   ///< B multiplicand (right-hand factor of A * B)
  kC = 2,   ///< Source accumulator
  kD = 3    ///< Destination operand (result of A * B + C)
};
45 
47 
49 template <
51  int M = 1,
53  int N = 1,
55  int K = 1
56 >
57 struct GemmShape {
58  static int const kM = M;
59  static int const kN = N;
60  static int const kK = K;
61 
62  static int const kMN = M * N;
63  static int const kMK = M * K;
64  static int const kKN = N * K;
65  static int const kMNK = M * N * K;
66 
67  static int const kCount = kMNK;
68 
69 
70  //
71  // Static member functions
72  //
73 
76  static Coord<3> toCoord() {
77  return make_Coord(kM, kN, kK);
78  }
79 };
80 
82 
84 template <
86  typename Shape
87 >
89 
91 
94 struct GemmCoord : public Coord<3, int> {
95 
97  typedef int Index;
98 
101 
103  static int const kM = 0;
104 
106  static int const kN = 1;
107 
109  static int const kK = 2;
110 
111  //
112  // Methods
113  //
114 
117  GemmCoord() { }
118 
121  GemmCoord(Coord<3, Index> const &coord): Base(make_Coord(coord[0], coord[1], coord[2])) { }
122 
125  GemmCoord(Index m, Index n, Index k): Base(make_Coord(m, n, k)) { }
126 
129  Index const & m() const { return this->at(kM); }
130 
133  Index & m() { return this->at(kM); }
134 
137  Index const & n() const { return this->at(kN); }
138 
141  Index & n() { return this->at(kN); }
142 
145  Index const & k() const { return this->at(kK); }
146 
149  Index & k() { return this->at(kK); }
150 
153  Coord<3> mnk() const {
154  return make_Coord(m(), n(), k());
155  }
156 
159  Coord<3> knm() const {
160  return make_Coord(k(), n(), m());
161  }
162 
165  Coord<2> nm() const {
166  return make_Coord(n(), m());
167  }
168 
171  Coord<2> mn() const {
172  return make_Coord(m(), n());
173  }
174 
177  Coord<2> mk() const {
178  return make_Coord(m(), k());
179  }
180 
183  Coord<2> km() const {
184  return make_Coord(k(), m());
185  }
186 
189  Coord<2> nk() const {
190  return make_Coord(n(), k());
191  }
192 
195  Coord<2> kn() const {
196  return make_Coord(k(), n());
197  }
198 
199  //
200  // Coord operators
201  //
202 
205  GemmCoord operator+(Base const& b) const {
206  return GemmCoord(Base::operator+(b));
207  }
208 
211  GemmCoord operator-(Base const& b) const {
212  return GemmCoord(Base::operator-(b));
213  }
214 
217  GemmCoord operator*(Base const& b) const {
218  return GemmCoord(Base::operator*(b));
219  }
220 
223  GemmCoord operator/(Base const& b) const {
224  return GemmCoord(Base::operator/(b));
225  }
226 
229  GemmCoord& operator+=(Base const& b) {
230  Base::operator+=(b);
231  return *this;
232  }
233 
236  GemmCoord& operator-=(Base const& b) {
237  Base::operator-=(b);
238  return *this;
239  }
240 
243  GemmCoord& operator*=(Base const& b) {
244  Base::operator*=(b);
245  return *this;
246  }
247 
250  GemmCoord& operator/=(Base const& b) {
251  Base::operator/=(b);
252  return *this;
253  }
254 };
255 
257 
260 struct BatchedGemmCoord : public Coord<4, int> {
261 
263  typedef int Index;
264 
267 
269  static int const kM = 0;
270 
272  static int const kN = 1;
273 
275  static int const kK = 2;
276 
278  static int const kBatch = 3;
279 
280  //
281  // Methods
282  //
283 
287 
290  BatchedGemmCoord(Base const &coord): Base(coord) { }
291 
294  BatchedGemmCoord(Index m, Index n, Index k, Index b): Base(make_Coord(m, n, k, b)) { }
295 
298  Index const & m() const { return this->at(kM); }
299 
302  Index & m() { return this->at(kM); }
303 
306  Index const & n() const { return this->at(kN); }
307 
310  Index & n() { return this->at(kN); }
311 
314  Index const & k() const { return this->at(kK); }
315 
318  Index & k() { return this->at(kK); }
319 
322  Index const & batch() const { return this->at(kBatch); }
323 
326  Index & batch() { return this->at(kBatch); }
327 
330  GemmCoord mnk() const {
331  return GemmCoord(m(), n(), k());
332  }
333 
336  Coord<4> mnkb() const {
337  return make_Coord(m(), n(), k(), batch());
338  }
339 
340  //
341  // Coord operators
342  //
343 
346  BatchedGemmCoord operator+(Base const& b) const {
347  return BatchedGemmCoord(Base::operator+(b));
348  }
349 
352  BatchedGemmCoord operator-(Base const& b) const {
353  return BatchedGemmCoord(Base::operator-(b));
354  }
355 
358  BatchedGemmCoord operator*(Base const& b) const {
359  return BatchedGemmCoord(Base::operator*(b));
360  }
361 
364  BatchedGemmCoord operator/(Base const& b) const {
365  return BatchedGemmCoord(Base::operator/(b));
366  }
367 
370  BatchedGemmCoord& operator+=(Base const& b) {
371  Base::operator+=(b);
372  return *this;
373  }
374 
377  BatchedGemmCoord& operator-=(Base const& b) {
378  Base::operator-=(b);
379  return *this;
380  }
381 
384  BatchedGemmCoord& operator*=(Base const& b) {
385  Base::operator*=(b);
386  return *this;
387  }
388 
391  BatchedGemmCoord& operator/=(Base const& b) {
392  Base::operator/=(b);
393  return *this;
394  }
395 };
396 
398 
399 } // namespace gemm
400 } // namespace cutlass
CUTLASS_HOST_DEVICE Coord< 4 > mnkb() const
Obtains a Coord<4> from BatchedGemmCoord.
Definition: include/cutlass/gemm/gemm.h:336
Coord< 4, Index > Base
Base type is a Coord of rank=4.
Definition: include/cutlass/gemm/gemm.h:266
CUTLASS_HOST_DEVICE Index & m()
Returns reference to the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:302
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE BatchedGemmCoord operator/(Base const &b) const
Element-wise division.
Definition: include/cutlass/gemm/gemm.h:364
CUTLASS_HOST_DEVICE GemmCoord & operator/=(Base const &b)
In-place division.
Definition: include/cutlass/gemm/gemm.h:250
int Index
Integer-valued index.
Definition: include/cutlass/gemm/gemm.h:97
CUTLASS_HOST_DEVICE GemmCoord(Coord< 3, Index > const &coord)
Constructs from Coord<3>.
Definition: include/cutlass/gemm/gemm.h:121
CUTLASS_HOST_DEVICE GemmCoord mnk() const
Obtains a GemmCoord from BatchedGemmCoord.
Definition: include/cutlass/gemm/gemm.h:330
CUTLASS_HOST_DEVICE GemmCoord operator+(Base const &b) const
Element-wise addition.
Definition: include/cutlass/gemm/gemm.h:205
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
CUTLASS_HOST_DEVICE Index & m()
Returns reference to the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:133
Operand
GEMM operand enumeration: D = A * B + C.
Definition: include/cutlass/gemm/gemm.h:39
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:171
CUTLASS_HOST_DEVICE half_t & operator/=(half_t &lhs, half_t const &rhs)
Definition: half.h:684
CUTLASS_HOST_DEVICE BatchedGemmCoord operator-(Base const &b) const
Element-wise subtraction.
Definition: include/cutlass/gemm/gemm.h:352
CUTLASS_HOST_DEVICE Index & batch()
Returns reference to the GEMM batch coordinate.
Definition: include/cutlass/gemm/gemm.h:326
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
CUTLASS_HOST_DEVICE Coord< 2 > nm() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:165
CUTLASS_HOST_DEVICE BatchedGemmCoord operator+(Base const &b) const
Element-wise addition.
Definition: include/cutlass/gemm/gemm.h:346
CUTLASS_HOST_DEVICE Index & k()
Returns reference to the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:149
CUTLASS_HOST_DEVICE Index & n()
Returns reference to the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:310
CUTLASS_HOST_DEVICE GemmCoord operator/(Base const &b) const
Element-wise division.
Definition: include/cutlass/gemm/gemm.h:223
CUTLASS_HOST_DEVICE GemmCoord(Index m, Index n, Index k)
Helper to construct from M, N, and K variables.
Definition: include/cutlass/gemm/gemm.h:125
int Index
Integer-valued index.
Definition: include/cutlass/gemm/gemm.h:263
CUTLASS_HOST_DEVICE half_t & operator+=(half_t &lhs, half_t const &rhs)
Definition: half.h:654
CUTLASS_HOST_DEVICE Coord< 2 > nk() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:189
CUTLASS_HOST_DEVICE BatchedGemmCoord(Base const &coord)
Constructs from Coord<4>
Definition: include/cutlass/gemm/gemm.h:290
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
CUTLASS_HOST_DEVICE half_t & operator-=(half_t &lhs, half_t const &rhs)
Definition: half.h:664
Coord< 3, Index > Base
Base type is a Coord of rank=3.
Definition: include/cutlass/gemm/gemm.h:100
static CUTLASS_HOST_DEVICE Coord< 3 > toCoord()
Returns a Coord object.
Definition: include/cutlass/gemm/gemm.h:76
CUTLASS_HOST_DEVICE Coord< 2 > km() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:183
Definition: include/cutlass/gemm/gemm.h:260
CUTLASS_HOST_DEVICE Coord< 3 > mnk() const
Obtains a Coord<3> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:153
CUTLASS_HOST_DEVICE GemmCoord operator-(Base const &b) const
Element-wise subtraction.
Definition: include/cutlass/gemm/gemm.h:211
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: include/cutlass/gemm/gemm.h:322
CUTLASS_HOST_DEVICE BatchedGemmCoord & operator*=(Base const &b)
In-place multiplication.
Definition: include/cutlass/gemm/gemm.h:384
CUTLASS_HOST_DEVICE BatchedGemmCoord operator*(Base const &b) const
Element-wise multiplication.
Definition: include/cutlass/gemm/gemm.h:358
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:314
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_HOST_DEVICE Coord< 2 > mk() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:177
CUTLASS_HOST_DEVICE BatchedGemmCoord()
Default ctor.
Definition: include/cutlass/gemm/gemm.h:286
CUTLASS_HOST_DEVICE Index & k()
Returns reference to the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:318
Source accumulator.
CUTLASS_HOST_DEVICE Coord< 3 > knm() const
Obtains a Coord<3> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:159
CUTLASS_HOST_DEVICE GemmCoord & operator+=(Base const &b)
In-place addition.
Definition: include/cutlass/gemm/gemm.h:229
CUTLASS_HOST_DEVICE half_t & operator*=(half_t &lhs, half_t const &rhs)
Definition: half.h:674
CUTLASS_HOST_DEVICE GemmCoord()
Default ctor.
Definition: include/cutlass/gemm/gemm.h:117
CUTLASS_HOST_DEVICE BatchedGemmCoord & operator/=(Base const &b)
In-place division.
Definition: include/cutlass/gemm/gemm.h:391
CUTLASS_HOST_DEVICE GemmCoord & operator-=(Base const &b)
In-place subtraction.
Definition: include/cutlass/gemm/gemm.h:236
CUTLASS_HOST_DEVICE BatchedGemmCoord & operator+=(Base const &b)
In-place addition.
Definition: include/cutlass/gemm/gemm.h:370
CUTLASS_HOST_DEVICE Coord< 2 > kn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:195
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:306
CUTLASS_HOST_DEVICE BatchedGemmCoord(Index m, Index n, Index k, Index b)
Helper to construct from M, N, K, and batch variables.
Definition: include/cutlass/gemm/gemm.h:294
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
CUTLASS_HOST_DEVICE GemmCoord & operator*=(Base const &b)
In-place multiplication.
Definition: include/cutlass/gemm/gemm.h:243
CUTLASS_HOST_DEVICE BatchedGemmCoord & operator-=(Base const &b)
In-place subtraction.
Definition: include/cutlass/gemm/gemm.h:377
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE Index & n()
Returns reference to the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:141
CUTLASS_HOST_DEVICE GemmCoord operator*(Base const &b) const
Element-wise multiplication.
Definition: include/cutlass/gemm/gemm.h:217
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:298