CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_base.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/aligned_buffer.h"
32 #include "cutlass/arch/memory.h"
33 #include "cutlass/array.h"
34 #include "cutlass/cutlass.h"
35 #include "cutlass/gemm/gemm.h"
36 #include "cutlass/matrix_shape.h"
37 #include "cutlass/numeric_types.h"
39 
40 namespace cutlass {
41 namespace gemm {
42 namespace threadblock {
43 
45 
/// Policy object describing details of the warp-level MMA used by a
/// threadblock-scoped GEMM (concept consumed by MmaBase below).
template <
    /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
    typename Operator_,
    /// Padding used for A operand in shared memory (concept: MatrixShape)
    typename SmemPaddingA_,
    /// Padding used for B operand in shared memory (concept: MatrixShape)
    typename SmemPaddingB_,
    /// Number of partitions of the K dimension of the GEMM
    int PartitionsK = 1>
struct MmaPolicy {
  /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
  using Operator = Operator_;

  /// Padding used for A operand in shared memory
  using SmemPaddingA = SmemPaddingA_;

  /// Padding used for B operand in shared memory
  using SmemPaddingB = SmemPaddingB_;

  /// Number of partitions of K dimension
  static int const kPartitionsK = PartitionsK;
};
69 
71 
74 template <
76  typename Shape_,
78  typename Policy_,
80  int Stages,
82  typename Enable = bool>
83 class MmaBase {
84  public:
86  using Shape = Shape_;
87 
89  using Policy = Policy_;
90 
91  //
92  // Dependent types
93  //
94 
96  using Operator = typename Policy::Operator;
97 
100  using WarpGemm = typename Policy::Operator::Shape;
101 
103  using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
104  Shape::kN / WarpGemm::kN,
105  Shape::kK / WarpGemm::kK>;
106 
108  static int const kWarpGemmIterations =
109  (WarpGemm::kK / Operator::Policy::MmaShape::kK);
110 
112  static int const kStages = Stages;
113 
116 
119 
120  //
121  // Nested structs
122  //
123 
126  public:
127  //
128  // Type definitions
129  //
130 
132  using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
133  Shape::kK * kStages +
134  Policy::SmemPaddingA::kColumn>;
135 
137  using ShapeB =
138  MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
139  Shape::kN + Policy::SmemPaddingB::kColumn>;
140 
141  public:
142  //
143  // Data members
144  //
145 
148 
151 
152  public:
153 
154  //
155  // Methods
156  //
157 
159  CUTLASS_DEVICE
160  static typename Operator::LayoutA LayoutA() {
161  return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
162  }
163 
166  static typename Operator::LayoutB LayoutB() {
167  return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
168  }
169 
173  return TensorRefA{operand_A.data(), LayoutA()};
174  }
175 
179  return TensorRefB{operand_B.data(), LayoutB()};
180  }
181  };
182 
183  protected:
184 
185  //
186  // Data members
187  //
188 
190  typename Operator::IteratorA warp_tile_iterator_A_;
191 
193  typename Operator::IteratorB warp_tile_iterator_B_;
194 
195 public:
196 
198  CUTLASS_DEVICE
201  SharedStorage &shared_storage,
203  int thread_idx,
205  int warp_idx,
207  int lane_idx
208  ):
209  warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
210  warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {
211 
212  }
213 };
214 
216 
217 } // namespace threadblock
218 } // namespace gemm
219 } // namespace cutlass
220 
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Architecture-specific operators on memory.
AlignedBuffer< typename Operator::ElementB, ShapeB::kCount > operand_B
Buffer for B operand.
Definition: mma_base.h:150
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of B operand from shared memory.
Definition: mma_base.h:193
typename Policy::Operator::Shape WarpGemm
Definition: mma_base.h:100
Defines common types used for all GEMM-like operators.
Shared storage object needed by threadblock-scoped GEMM.
Definition: mma_base.h:125
Shape_ Shape
Policy describing tuning details.
Definition: mma_base.h:88
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Operator_ Operator
Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) ...
Definition: mma_base.h:58
SmemPaddingA_ SmemPaddingA
Padding used for A operand in shared memory.
Definition: mma_base.h:61
Defines a Shape template for matrix tiles.
static CUTLASS_HOST_DEVICE Operator::LayoutB LayoutB()
Returns a layout object for the B matrix.
Definition: mma_base.h:166
Policy object describing MmaTensorOp.
Definition: mma_base.h:56
Definition: tensor_ref.h:146
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Modifies semantics of cutlass::Array<> to provide guaranteed alignment.
Definition: aligned_buffer.h:45
CUTLASS_HOST_DEVICE TensorRefA operand_A_ref()
Returns a TensorRef to the A operand.
Definition: mma_base.h:172
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_HOST_DEVICE pointer data()
Definition: aligned_buffer.h:84
CUTLASS_DEVICE MmaBase(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Construct from tensor references.
Definition: mma_base.h:199
Definition: mma_base.h:83
typename Policy::Operator Operator
Warp-level Mma.
Definition: mma_base.h:96
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of A operand from shared memory.
Definition: mma_base.h:190
AlignedBuffer< typename Operator::ElementA, ShapeA::kCount > operand_A
Buffer for A operand.
Definition: mma_base.h:147
static CUTLASS_DEVICE Operator::LayoutA LayoutA()
Returns a layout object for the A matrix.
Definition: mma_base.h:160
SmemPaddingB_ SmemPaddingB
Padding used for B operand in shared memory.
Definition: mma_base.h:64
static int const kPartitionsK
Number of partitions of K dimension.
Definition: mma_base.h:67
CUTLASS_HOST_DEVICE TensorRefB operand_B_ref()
Returns a TensorRef to the B operand.
Definition: mma_base.h:178
Basic include for CUTLASS.