CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_sm72.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include "cutlass/layout/matrix.h"
33 
35 namespace cutlass {
36 namespace arch {
37 
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// WMMA template structure defines nvcuda::wmma::fragments and static assertions for
// the wmma native instruction shapes supported for int8_t multiplicands.
//
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Shape_,      ///< Size of the matrix product (concept: GemmShape)
typename LayoutA_,    ///< Layout of A matrix (concept: MatrixLayout)
typename LayoutB_,    ///< Layout of B matrix (concept: MatrixLayout)
typename LayoutC_>    ///< Layout of C matrix (concept: MatrixLayout)
struct Wmma<
  Shape_,
  int8_t,
  LayoutA_,
  int8_t,
  LayoutB_,
  int32_t,
  LayoutC_,
  cutlass::arch::OpMultiplyAdd
> {
#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)
  using Shape = Shape_;
  using ElementA = int8_t;
  using LayoutA = LayoutA_;
  using ElementB = int8_t;
  using LayoutB = LayoutB_;
  using ElementC = int32_t;
  using LayoutC = LayoutC_;
  using Operator = cutlass::arch::OpMultiplyAdd;

  // Check supported wmma shapes for the given multiplicand data types.
  // NOTE(review): the condition below was reconstructed to match the message string;
  // the documentation extraction dropped these lines — confirm against upstream.
  static_assert(
    platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape<32,  8, 16>, Shape>::value,
    "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");

  // Wmma fragment types mapping CUTLASS element and layout types onto the
  // corresponding nvcuda::wmma fragment template parameters.
  using FragmentA = nvcuda::wmma::fragment<
          nvcuda::wmma::matrix_a,
          Shape::kM,
          Shape::kN,
          Shape::kK,
          typename CutlassToWmmaDataType<ElementA>::Type,
          typename CutlassToWmmaLayout<LayoutA>::Layout>;

  using FragmentB = nvcuda::wmma::fragment<
          nvcuda::wmma::matrix_b,
          Shape::kM,
          Shape::kN,
          Shape::kK,
          typename CutlassToWmmaDataType<ElementB>::Type,
          typename CutlassToWmmaLayout<LayoutB>::Layout>;

  using FragmentC = nvcuda::wmma::fragment<
          nvcuda::wmma::accumulator,
          Shape::kM,
          Shape::kN,
          Shape::kK,
          typename CutlassToWmmaDataType<ElementC>::Type>;

  /// Performs an nvcuda::wmma matrix multiply-accumulate operation: D = A * B + C
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    FragmentA const &A,
    FragmentB const &B,
    FragmentC const &C) const {

    nvcuda::wmma::mma_sync(D, A, B, C);
  }

#else
  static_assert(false, "wmma.mma.sync integer type multiplicands are available only for SM72 and beyond");
#endif

};
117 
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// WMMA template structure defines nvcuda::wmma::fragments and static assertions for
// the wmma native instruction shapes supported for uint8_t multiplicands.
//
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Shape_,      ///< Size of the matrix product (concept: GemmShape)
typename LayoutA_,    ///< Layout of A matrix (concept: MatrixLayout)
typename LayoutB_,    ///< Layout of B matrix (concept: MatrixLayout)
typename LayoutC_>    ///< Layout of C matrix (concept: MatrixLayout)
struct Wmma<
  Shape_,
  uint8_t,
  LayoutA_,
  uint8_t,
  LayoutB_,
  int32_t,
  LayoutC_,
  cutlass::arch::OpMultiplyAdd
> {
#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)
  using Shape = Shape_;
  using ElementA = uint8_t;
  using LayoutA = LayoutA_;
  using ElementB = uint8_t;
  using LayoutB = LayoutB_;
  using ElementC = int32_t;
  using LayoutC = LayoutC_;
  using Operator = cutlass::arch::OpMultiplyAdd;

  // Check supported wmma shapes for the given multiplicand data types.
  // NOTE(review): the condition below was reconstructed to match the message string;
  // the documentation extraction dropped these lines — confirm against upstream.
  static_assert(
    platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape<32,  8, 16>, Shape>::value,
    "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");

  // Wmma fragment types mapping CUTLASS element and layout types onto the
  // corresponding nvcuda::wmma fragment template parameters.
  using FragmentA = nvcuda::wmma::fragment<
          nvcuda::wmma::matrix_a,
          Shape::kM,
          Shape::kN,
          Shape::kK,
          typename CutlassToWmmaDataType<ElementA>::Type,
          typename CutlassToWmmaLayout<LayoutA>::Layout>;

  using FragmentB = nvcuda::wmma::fragment<
          nvcuda::wmma::matrix_b,
          Shape::kM,
          Shape::kN,
          Shape::kK,
          typename CutlassToWmmaDataType<ElementB>::Type,
          typename CutlassToWmmaLayout<LayoutB>::Layout>;

  using FragmentC = nvcuda::wmma::fragment<
          nvcuda::wmma::accumulator,
          Shape::kM,
          Shape::kN,
          Shape::kK,
          typename CutlassToWmmaDataType<ElementC>::Type>;

  /// Performs an nvcuda::wmma matrix multiply-accumulate operation: D = A * B + C
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    FragmentA const &A,
    FragmentB const &B,
    FragmentC const &C) const {

    nvcuda::wmma::mma_sync(D, A, B, C);
  }

#else
  static_assert(false, "wmma.mma.sync integer type multiplicands are available only for SM72 and beyond");
#endif

};
196 
197 } // namespace arch
198 } // namespace cutlass
Definition: aligned_buffer.h:35
std::is_same (false specialization)
Definition: platform.h:394
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
Defines layout functions used by TensorRef and derived classes.