CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tile_iterator_wmma_tensor_op.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Template for reading and writing tiles of accumulators to shared memory.
*/

#pragma once

#if !defined(__clang__)

#include "cutlass/cutlass.h"
#include "cutlass/wmma_array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/epilogue/warp/wmma_tensor_op_policy.h"

/////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace warp {

/////////////////////////////////////////////////////////////////////////////

/// Template for reading and writing tiles of accumulators to shared memory
template <
  typename WarpShape,
  typename OperatorShape,
  typename OperatorFragment,
  typename Layout
>
class TileIteratorWmmaTensorOp;

/////////////////////////////////////////////////////////////////////////////

/// Template for reading and writing tiles of accumulators to shared memory
template <
  typename WarpShape_,
  typename OperatorShape_,
  typename OperatorFragment_
>
class TileIteratorWmmaTensorOp<WarpShape_, OperatorShape_, OperatorFragment_, layout::RowMajor> {
public:

  using WarpShape = WarpShape_;
  using OperatorShape = OperatorShape_;
  using OperatorFragment = OperatorFragment_;
  using Layout = layout::RowMajor;

  //
  // Derived types
  //
  using WmmaDataType = typename OperatorFragment::element_type;
  using Element = typename cutlass::arch::WmmaToCutlassDataType<WmmaDataType>::Type; ///< data type of the element stored in an nvcuda::wmma::fragment
  using TensorRef = TensorRef<Element, Layout>;  ///< tensor reference object
  using TensorCoord = MatrixCoord;               ///< logical coordinate in referenced tensor
  using Index = typename TensorRef::Index;
  using LongIndex = typename TensorRef::LongIndex;

  using Policy = WmmaTensorOpPolicy<WarpShape, OperatorShape, Layout>;

  /// Shape of the tile in memory
  using Shape = MatrixShape<
    Policy::kRowsPerIteration,
    WarpShape::kN
  >;

  /// This is the fragment size produced by one access of the iterator.
  using Fragment = WmmaFragmentArray<OperatorFragment, Policy::OperatorCount::kColumn * Policy::kWmmaFragmentsPerAccess>;

  /// This is the complete warp-level accumulator tile.
  //using AccumulatorTile = typename Operator::FragmentC;

  /// Padding quantity
  // (Epilogue shared memory padding for the WMMA GEMM kernel is set to run optimally on Turing)
  using Padding = MatrixShape<
    0,
    4 * Policy::kElementsPerAccess
  >;

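  // Added note (not in the original source): the extra padding columns
  // presumably stagger successive tile rows across shared-memory banks so
  // that the WMMA store/load accesses below avoid bank conflicts on
  // Turing-class hardware.
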
private:

  /// Storage type for accessing memory
  //using AccessType = AlignedArray<Element, Policy::kElementsPerAccess>;

  //
  // Data members
  //

  /// Internal reference
  TensorRef ref_;

public:

  /// Default constructor
  CUTLASS_HOST_DEVICE
  TileIteratorWmmaTensorOp(): ref_(nullptr) {

  }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  TileIteratorWmmaTensorOp(
    TensorRef const &ref,
    unsigned lane_id
  ): ref_(ref) {
  }

  /// Adds a pointer offset
  CUTLASS_HOST_DEVICE
  TileIteratorWmmaTensorOp & add_pointer_offset(Index pointer_offset) {
    ref_.add_pointer_offset(pointer_offset);
    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_HOST_DEVICE
  TileIteratorWmmaTensorOp & add_tile_offset(TensorCoord const &tile_offset) {
    ref_.add_coord_offset({tile_offset.row() * OperatorShape::kM, tile_offset.column() * WarpShape::kN});
    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_HOST_DEVICE
  TileIteratorWmmaTensorOp & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  /// Store
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {

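    // Added commentary (not in the original source): each WMMA accumulator
    // fragment n covers OperatorShape::kN columns of the tile, so the loop
    // writes fragment n at column offset n * OperatorShape::kN, using the
    // tensor's leading dimension (ref_.stride()[0]) as the row pitch.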
    for(int n=0; n < Policy::OperatorCount::kColumn; n++) {

      WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(ref_.data() + ref_.offset({0, n * OperatorShape::kN}) + pointer_offset);

      nvcuda::wmma::store_matrix_sync(
        ptr,
        frag[n],
        ref_.stride()[0],
        nvcuda::wmma::layout_t::mem_row_major
      );

    }
  }

  /// Store
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) {
    store_with_pointer_offset(frag, 0);
  }

  /// Load
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

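    // Added commentary (not in the original source): the load path mirrors
    // the store above, reading each fragment back from the same column
    // offsets with wmma::load_matrix_sync in row-major layout.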
    for(int n=0; n < Policy::OperatorCount::kColumn; n++) {

      WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(ref_.data() + ref_.offset({0, n * OperatorShape::kN}) + pointer_offset);

      nvcuda::wmma::load_matrix_sync(
        frag[n],
        ptr,
        ref_.stride()[0],
        nvcuda::wmma::layout_t::mem_row_major
      );

    }
  }

  /// Load
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }
};

/////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////

#endif // !defined(__clang__)
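For orientation, a minimal usage sketch follows (not part of the file above): it shows how an epilogue might instantiate and drive this iterator. The warp and operator shapes, the float accumulator type, and the epilogue_sketch function are hypothetical placeholders chosen for illustration.

#include <mma.h>

#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h"

// Hypothetical shapes: a 32x32x16 warp tile built from 16x16x16 WMMA ops.
using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
using OperatorShape = cutlass::gemm::GemmShape<16, 16, 16>;
using OperatorFragment =
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float>;

using Iterator = cutlass::epilogue::warp::TileIteratorWmmaTensorOp<
    WarpShape, OperatorShape, OperatorFragment, cutlass::layout::RowMajor>;

// Illustrative device-side helper (hypothetical): writes one warp's
// accumulator fragments to a shared-memory tile, then advances the iterator.
__device__ void epilogue_sketch(Iterator::TensorRef smem_ref,
                                Iterator::Fragment const &accum,
                                unsigned lane_id) {
  Iterator iter(smem_ref, lane_id); // bind the iterator to the smem tile
  iter.store(accum);                // one store covers a full row-slice
  iter.add_tile_offset({1, 0});     // advance one operator tile of rows
}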