CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_sm70.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include "cutlass/layout/matrix.h"
33 
35 namespace cutlass {
36 namespace arch {
37 
38 
40 //
41 // Wmma partial specialization for f16 multiplicands: A and B are cutlass::half_t and
42 // the operator is cutlass::arch::OpMultiplyAdd. It maps the CUTLASS element types and
// layouts onto nvcuda::wmma::fragment types, static_asserts the wmma shapes and the
// accumulator types supported for f16 inputs, and forwards to nvcuda::wmma::mma_sync.
//
// NOTE(review): the primary Wmma template, CutlassToWmmaDataType and
// CutlassToWmmaLayout are declared outside this listing, and nothing visible here
// includes <mma.h> for nvcuda::wmma — presumably the including header does; confirm.
43 //
45 template <
46 typename Shape_,
47 typename LayoutA_,
48 typename LayoutB_,
49 typename ElementC_,
50 typename LayoutC_>
51 struct Wmma<
52  Shape_,
53  cutlass::half_t,
54  LayoutA_,
// NOTE(review): listing line 55 is missing between LayoutA_ and LayoutB_. Because
// ElementB is aliased to cutlass::half_t below, it presumably held the ElementB
// argument (cutlass::half_t) — verify against the actual header.
56  LayoutB_,
57  ElementC_,
58  LayoutC_,
59  cutlass::arch::OpMultiplyAdd
60 > {
61 
  // The usable definition exists only when CUTLASS_ARCH_WMMA_SM70_ENABLED is defined
  // (the macro is set outside this listing); otherwise the #else branch below applies.
62 #if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED)
  // Re-export the template parameters; A and B element types are fixed to
  // cutlass::half_t by this specialization.
63  using Shape = Shape_;
64  using ElementA = cutlass::half_t;
65  using LayoutA = LayoutA_;
66  using ElementB = cutlass::half_t;
67  using LayoutB = LayoutB_;
68  using ElementC = ElementC_;
69  using LayoutC = LayoutC_;
70  using Operator = cutlass::arch::OpMultiplyAdd;
71 
72  // check that Shape is one of the wmma shapes supported for f16 multiplicands
  // NOTE(review): listing lines 73-76 (the static_assert and its condition) are not
  // rendered here; only the diagnostic message survives below.
  // NOTE(review): "8x328x16" in this message looks like a typo for 8x32x16 (it
  // mirrors the 32x8x16 entry) — fix the string literal in the real header.
77  "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");
78 
79  // check that ElementC is a wmma accumulator type supported for f16 multiplicands
  // NOTE(review): listing lines 80-81 (the condition) are not rendered here; per the
  // message, ElementC must be f16 or f32. The message grammar ("Supported of ...
  // are") could read "Supported wmma output data types ... are".
82  "Supported of wmma output data type for f16 multiplicands are: f16 and f32");
83 
84  // Wmma fragments: A and B use the wmma element type and layout mapped from the
  // CUTLASS types via CutlassToWmmaDataType / CutlassToWmmaLayout (defined elsewhere).
  // All three fragments are sized by Shape::kM, Shape::kN, Shape::kK. The accumulator
  // fragment takes no layout template argument.
85  using FragmentA = nvcuda::wmma::fragment<
86  nvcuda::wmma::matrix_a,
87  Shape::kM,
88  Shape::kN,
89  Shape::kK,
90  typename CutlassToWmmaDataType<ElementA>::Type,
91  typename CutlassToWmmaLayout<LayoutA>::Layout>;
92 
93  using FragmentB = nvcuda::wmma::fragment<
94  nvcuda::wmma::matrix_b,
95  Shape::kM,
96  Shape::kN,
97  Shape::kK,
98  typename CutlassToWmmaDataType<ElementB>::Type,
99  typename CutlassToWmmaLayout<LayoutB>::Layout>;
100 
101  using FragmentC = nvcuda::wmma::fragment<
102  nvcuda::wmma::accumulator,
103  Shape::kM,
104  Shape::kN,
105  Shape::kK,
106  typename CutlassToWmmaDataType<ElementC>::Type>;
107 
  /// Multiply-add on one wmma tile: forwards to nvcuda::wmma::mma_sync(D, A, B, C),
  /// which per the CUDA WMMA API computes D = A * B + C. mma_sync is warp-synchronous,
  /// so every lane of the warp is expected to reach this call together.
  ///
  /// D - output accumulator fragment (written)
  /// A - f16 matrix_a fragment
  /// B - f16 matrix_b fragment
  /// C - input accumulator fragment
109  CUTLASS_DEVICE
110  void operator()(
111  FragmentC &D,
112  FragmentA const &A,
113  FragmentB const &B,
114  FragmentC const &C) const {
115 
116  nvcuda::wmma::mma_sync(D, A, B, C);
117  }
118 #else
  // Without CUTLASS_ARCH_WMMA_SM70_ENABLED this branch reports a compile-time error.
  // NOTE(review): the condition `false` does not depend on a template parameter, so
  // (before C++23) compilers may reject it as soon as this branch is compiled, even if
  // the specialization is never instantiated. If that is not intended, make it
  // dependent (e.g. on Shape_). Also "avialable" in the message should be "available".
  // NOTE(review): cutlass/platform.h may #define static_assert as a macro in some
  // configurations, in which case this check may expand to nothing — confirm.
119  static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond");
120 #endif
121 
122 };
123 
124 } // namespace arch
125 } // namespace cutlass
Definition: aligned_buffer.h:35
std::is_same (false specialization)
Definition: platform.h:394
IEEE half-precision floating-point type.
Definition: half.h:126
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
Defines layout functions used by TensorRef and derived classes.