CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
shared_load_iterator.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
33 #pragma once
34 
35 #include "cutlass/cutlass.h"
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/array.h"
38 #include "cutlass/layout/matrix.h"
39 #include "cutlass/matrix_shape.h"
40 #include "cutlass/tensor_ref.h"
41 
43 
45 
46 namespace cutlass {
47 namespace epilogue {
48 namespace threadblock {
49 
51 
56 template <
57  typename ThreadMap_,
58  typename Element_,
59  int MaxAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits<Element_>::value / 8
60 >
62 public:
63  using ThreadMap = ThreadMap_;
64  using Shape = typename ThreadMap::Shape;
65 
66  using Element = Element_;
67 
71 
72  using Index = typename Layout::Index;
73  using LongIndex = typename Layout::LongIndex;
75 
76  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
77 
78  static int const kMinAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits<Element_>::value / 8;
79 
80  static int const kAlignment = (MaxAlignment < kMinAlignment ? MaxAlignment : kMinAlignment);
81 
82  static int const kThreads = ThreadMap::kThreads;
83 
85  using Fragment = Array<
86  Element,
87  ThreadMap::Iterations::kColumn *
88  ThreadMap::Iterations::kRow *
89  ThreadMap::Iterations::kGroup *
90  ThreadMap::Iterations::kCluster *
91  ThreadMap::kElementsPerAccess>;
92 
94  using AccessType = AlignedArray<
95  Element,
96  ThreadMap::kElementsPerAccess,
97  kAlignment>;
98 
99 private:
100 
101  //
102  // Data members
103  //
104 
106  uint8_t *byte_pointer_;
107 
109  int stride_;
110 
111 public:
112 
113  //
114  // Methods
115  //
116 
118  CUTLASS_DEVICE
120  TensorRef ref,
121  int thread_idx
122  ):
123  byte_pointer_(reinterpret_cast<uint8_t *>(ref.data())),
124  stride_((ref.stride(0) * sizeof_bits<Element>::value) / 8) {
125 
126  TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
127 
128  // Initialize pointer
129  byte_pointer_ +=
130  thread_offset.row() * stride_ +
131  thread_offset.column() * sizeof(AccessType) / kElementsPerAccess;
132 
133  int byte_offset = thread_offset.row() * stride_ +
134  thread_offset.column() * sizeof(AccessType) / kElementsPerAccess;
135  }
136 
139  void add_pointer_offset(LongIndex pointer_offset) {
140  byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
141  }
142 
143  CUTLASS_DEVICE
144  void add_tile_offset(TensorCoord const &offset) {
145  add_pointer_offset(offset.row() * stride_ / (sizeof_bits<Element>::value / 8) + offset.column() * Shape::kColumn);
146  }
147 
149  CUTLASS_DEVICE
150  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
151 
152  AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
153 
155  for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
156 
158  for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
159 
161  for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
162 
163  uint8_t const *byte_pointer = byte_pointer_ +
164  row * ThreadMap::Delta::kRow * stride_ +
165  group * ThreadMap::Delta::kGroup* stride_ +
166  cluster * ThreadMap::Delta::kCluster * stride_ +
167  pointer_offset * sizeof_bits<Element>::value / 8;
168 
169  int frag_row_idx =
170  (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
171 
172  AccessType const *memory_pointer = reinterpret_cast<AccessType const *>(byte_pointer);
173 
175  for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
176 
177  int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
178 
179  frag_ptr[frag_idx] =
180  memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess];
181  }
182  }
183  }
184  }
185  }
186 
188  CUTLASS_DEVICE
189  void load(Fragment &frag) {
190 
191  load_with_pointer_offset(frag, 0);
192  }
193 };
194 
196 
197 } // namespace threadblock
198 } // namespace epilogue
199 } // namespace cutlass
200 
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: shared_load_iterator.h:91
Definition: aligned_buffer.h:35
static int const value
Definition: numeric_types.h:43
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset)
Loads a fragment from memory.
Definition: shared_load_iterator.h:150
static int const kThreads
Definition: shared_load_iterator.h:82
Aligned array type.
Definition: array.h:511
CUTLASS_DEVICE SharedLoadIterator(TensorRef ref, int thread_idx)
Constructor.
Definition: shared_load_iterator.h:119
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
static int const kMinAlignment
Definition: shared_load_iterator.h:78
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: shared_load_iterator.h:70
TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef
TensorRef to constant data.
Definition: tensor_ref.h:179
ThreadMap_ ThreadMap
Definition: shared_load_iterator.h:63
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
static int const kAlignment
Definition: shared_load_iterator.h:80
Defines a Shape template for matrix tiles.
Defines the size of an element in bits.
Definition: numeric_types.h:42
AlignedArray< Element, ThreadMap::kElementsPerAccess, kAlignment > AccessType
Memory access size.
Definition: shared_load_iterator.h:97
typename Layout::Index Index
Definition: shared_load_iterator.h:72
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Metaprogram for determining the mapping of output elements to threads for epilogue tiles...
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment.
Definition: shared_load_iterator.h:189
Element_ Element
Definition: shared_load_iterator.h:66
Defines layout functions used by TensorRef and derived classes.
typename ThreadMap::Shape Shape
Definition: shared_load_iterator.h:64
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: shared_load_iterator.h:139
static int const kElementsPerAccess
Definition: shared_load_iterator.h:76
Definition: shared_load_iterator.h:61
typename Layout::LongIndex LongIndex
Definition: shared_load_iterator.h:73
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_DEVICE void add_tile_offset(TensorCoord const &offset)
Definition: shared_load_iterator.h:144